I want to do graph-to-text generation, so I train my own embeddings and pass them to the model as input to generate text.
But when I call .generate(), it raises an error.
Below is my code; plm is a BART model:
plm = BartForConditionalGeneration.from_pretrained(f"{config['data_dir']}/{config['data_type']}/module/{epoch_idx}")
nodes = nodes.to(device)
student_embeddings = student(nodes, edges, types)
node_masks = node_masks.to(device)
generated_ids = plm.generate(
    input_ids=None,
    inputs_embeds=student_embeddings,
    decoder_input_ids=(torch.zeros(0)).to(device),
    attention_mask=node_masks,
    num_beams=4,
    max_length=config["max_seq_length"],
    early_stopping=True,
)