I am using a T5 model for a seq2seq task, and I made sure to replace padding tokens in the labels with -100. Below is my tokenization setup:
max_source_length = 90
max_target_length = 90

def tokenization_function(batch):
    model_inputs = tokenizer(batch['user_request'], padding="max_length", max_length=max_source_length, truncation=True, return_tensors="pt")
    labels = tokenizer(batch['command'], padding="max_length", max_length=max_target_length, truncation=True, return_tensors="pt")
    model_inputs["decoder_attention_mask"] = labels['attention_mask']
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

tokenized_dataset = dataset.map(tokenization_function, batched=True, batch_size=1024)
tokenized_dataset
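As a quick sanity check on the tokenized columns, I look at one example roughly like below (this sketch assumes `dataset` is a single split; with a DatasetDict you would index a split first, e.g. tokenized_dataset['train'][0]):

# Inspect one tokenized example (single-split dataset assumed)
sample = tokenized_dataset[0]
print(sample.keys())            # user_request, command, input_ids, attention_mask, decoder_attention_mask, labels
print(len(sample["labels"]))    # 90, i.e. max_target_length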
After training, I run inference with the script below:
with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        pred = torch.argmax(outputs['logits'], axis=-1)
        for i, p in enumerate(pred):
            if torch.where(p == 1)[0].size(0) != 0:
                idx = torch.where(p == 1)[0][0]
                seq = p[:idx].reshape(1, -1)
            else:
                seq = p.reshape(1, -1)
            pred_text = tokenizer.batch_decode(seq)
            print(batch['command'][i])
            print(pred_text[0])
            print()
        break
For instance, pred[0] has the following value after applying argmax:
tensor([ 1041, 834, 6583, 283, 26479, 3876, 834, 6583, 3, 4254,
25528, 16, 10646, 834, 5540, 5839, 804, 834, 5540, 15959,
3856, 15, 44, 15959, 3138, 1, 1, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041,
1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041, 1041],
device='cuda:0')
Shouldn't the decoder's autoregression stop after predicting id 1? With my limited knowledge, I believe 1 corresponds to the end-of-sequence token </s>. Instead, why am I getting 1041 until the max length of 90 is reached? Is this the desired output? What should I do to stop the prediction right after the </s> token is predicted?
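As a side note, this is the quick check I did on the special-token ids (assuming the stock T5 tokenizer from transformers; the ids could differ for a custom vocabulary):

# Sanity check on the special tokens (stock T5 tokenizer assumed)
print(tokenizer.eos_token, tokenizer.eos_token_id)   # '</s>' 1
print(tokenizer.pad_token, tokenizer.pad_token_id)   # '<pad>' 0
print(tokenizer.convert_ids_to_tokens([1041]))       # to see which token keeps being repeated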
I am a beginner at working with language models, so please feel free to point out any other issues in the snippets.
cc: @nielsr