I have fine-tuned a T5 model for generation. It learns well and the loss decreases a lot. When I evaluate the model, I’m happy with the results when I do model.forward, but when I call model.generate, no matter what parameters I set for decoding, the model performs poorly. In fact, once the model has overfitted, model.generate produces something pretty similar every time. I’m unsure of what I’m doing wrong, and I’ve looked at lots of topics. It seems like right now I don’t even have to shift my inputs right for the decoder as long as I’m not passing decoder inputs, so I’m very unsure what the problem is.
Also, how do people normally decide what decoding method to use? Is it just based on what “looks good” on the validation set?
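One common approach is to treat the decoding settings as hyperparameters and choose them with an automatic metric on a held-out validation set, rather than by eye. Below is a minimal sketch of that idea, not a definitive recipe. The names `exact_match`, `pick_decoding_config`, and `fake_generate` are hypothetical, and `fake_generate` only stands in for a real call to model.generate followed by tokenizer.batch_decode.

```python
def exact_match(predictions, references):
    """Fraction of predictions that equal their reference after stripping whitespace."""
    matches = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return matches / len(references)


def pick_decoding_config(generate_fn, configs, inputs, references):
    """Score every candidate config on the validation data and keep the best one."""
    best_config, best_score = None, float("-inf")
    for config in configs:
        predictions = generate_fn(inputs, **config)
        score = exact_match(predictions, references)
        if score > best_score:  # ties keep the earlier config
            best_config, best_score = config, score
    return best_config, best_score


def fake_generate(inputs, num_beams=1, do_sample=False):
    # Hypothetical stand-in for model.generate + tokenizer.batch_decode:
    # deterministic decoding upper-cases the input, sampling reverses it.
    if do_sample:
        return [text[::-1] for text in inputs]
    return [text.upper() for text in inputs]


candidate_configs = [
    {"num_beams": 1, "do_sample": False},
    {"num_beams": 4, "do_sample": False},
    {"num_beams": 1, "do_sample": True},
]
val_inputs = ["hello", "world"]
val_references = ["HELLO", "WORLD"]

best_config, best_score = pick_decoding_config(
    fake_generate, candidate_configs, val_inputs, val_references
)
print(best_config, best_score)
```

For a real task, exact match would usually be swapped for whatever metric the task is judged on (for example ROUGE or BLEU), and the winning configuration would be confirmed once on a separate test set.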
Here is my tokenizing code:
import torch

encoding = tokenizer(
    [task_prefix + sequence for sequence in input_sequences],
    # the arguments below are assumed; they mirror the target call further down
    padding="longest",
    truncation=True,
    return_tensors="pt",
)
target_tensor_input, t5_masks = encoding.input_ids, encoding.attention_mask
target_encoding = tokenizer(
    output_sequences, padding="longest", max_length=max_target_length, truncation=True
)
labels = target_encoding.input_ids
# replace the padding token ids in the labels with -100 so they are ignored by the loss
t5_outputs = torch.tensor(labels)
t5_outputs[t5_outputs == tokenizer.pad_token_id] = -100
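To make the masking step concrete, here is a small sketch with made-up token ids and random logits. It assumes a pad token id of 0 (check tokenizer.pad_token_id for your tokenizer) and a hypothetical vocabulary size of 500. It shows that a cross-entropy loss with ignore_index=-100 averages over the non-padded positions only.

```python
import torch
import torch.nn.functional as F

pad_token_id = 0  # assumption for this sketch; read it from the tokenizer in practice
vocab_size = 500  # hypothetical vocabulary size

# Two made-up label sequences; the second one ends in one padding token.
labels = torch.tensor([[37, 423, 1], [8, 1, 0]])

# Replace padding ids with -100 so the loss skips those positions.
masked = labels.clone()
masked[masked == pad_token_id] = -100

torch.manual_seed(0)
logits = torch.randn(2, 3, vocab_size)  # random stand-in for model logits

# Loss with the -100 positions ignored.
loss = F.cross_entropy(logits.view(-1, vocab_size), masked.view(-1), ignore_index=-100)

# The same value computed by hand: average per-token loss over non-padded positions.
per_token = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
keep = labels.view(-1) != pad_token_id
manual = per_token[keep].mean()

print(masked.tolist(), loss.item(), manual.item())
```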
Here is a minimalist version of my model training code:
for target_tensor_input, t5_outputs, t5_masks in tqdm(train_dataloader):
    loss = model(input_ids=target_tensor_input.cuda(), attention_mask=t5_masks.cuda(), labels=t5_outputs.cuda()).loss
My evaluation code is similar:
with torch.no_grad():
    for target_tensor_input, t5_outputs, t5_masks in tqdm(train_dataloader):
        # the output here varies
        batch_outputs_generate = model.generate(input_ids=target_tensor_input.cuda(), min_length=2, max_length=max_target_length, do_sample=True, num_beams=8)
        batch_outputs_generate = tokenizer.batch_decode(batch_outputs_generate, skip_special_tokens=True)
        # the output here is good
        batch_outputs_forward = model(input_ids=target_tensor_input.cuda(), attention_mask=t5_masks.cuda(), labels=t5_outputs.cuda())
        batch_outputs_forward_output = tokenizer.batch_decode(torch.argmax(batch_outputs_forward.logits, dim=2).tolist(), skip_special_tokens=True)
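One thing worth keeping in mind when comparing the two outputs: when labels are passed, the forward pass is teacher-forced, so the prediction at each position is conditioned on the ground-truth previous tokens, whereas model.generate conditions on the model's own previous predictions. The argmax of the forward logits can therefore look much better than any generated sequence. The toy sketch below illustrates the effect with a hypothetical lookup table (`next_token_table`) standing in for the decoder. It is not T5 code.

```python
START = "<s>"

# Hypothetical toy "decoder": maps the previous token to the predicted next token.
# It makes exactly one mistake: after "the" it predicts "dog" instead of "cat".
next_token_table = {
    START: "the",
    "the": "dog",
    "cat": "sat",
    "dog": "barked",
    "sat": "down",
    "barked": "loudly",
}
target = ["the", "cat", "sat", "down"]


def teacher_forced_predictions(target_tokens):
    """Predict each token from the ground-truth previous token (like forward with labels)."""
    previous_tokens = [START] + target_tokens[:-1]
    return [next_token_table[token] for token in previous_tokens]


def free_running_generation(length):
    """Predict each token from the model's own previous prediction (like greedy generate)."""
    output, previous = [], START
    for _ in range(length):
        previous = next_token_table[previous]
        output.append(previous)
    return output


forced = teacher_forced_predictions(target)
generated = free_running_generation(len(target))

forced_correct = sum(p == t for p, t in zip(forced, target))
generated_correct = sum(p == t for p, t in zip(generated, target))
print(forced, forced_correct)
print(generated, generated_correct)
```

A single wrong step costs one token under teacher forcing but derails the rest of the free-running sequence. Separately, the evaluation loop above calls model.generate with do_sample=True, which makes the output stochastic by design, and without an explicit attention mask. Passing attention_mask=t5_masks.cuda() and setting do_sample=False may make the two outputs easier to compare.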