BART learns well, loss decreases, but prediction output is weird

I am trying to fine-tune BART for a seq2seq task. During training the loss decreases very nicely, but at inference time the output is weird (for example, the end-of-sequence token is predicted as the very first token). Although model.generate() strips those tokens and still produces some result, that result is suboptimal compared with what the training loss would suggest.

The training loss decreases as follows:

 10%|█         | 390/3900 [03:16<26:41,  2.19it/s]
Epoch 1 loss: 0.3951
 20%|██        | 780/3900 [06:32<23:40,  2.20it/s]
Epoch 2 loss: 0.3202
 30%|███       | 1170/3900 [09:49<20:44,  2.19it/s]
Epoch 3 loss: 0.2841
 40%|████      | 1560/3900 [13:05<17:45,  2.20it/s]
Epoch 4 loss: 0.2239
 50%|█████     | 1950/3900 [16:22<14:49,  2.19it/s]
Epoch 5 loss: 0.1809
 60%|██████    | 2340/3900 [19:38<11:49,  2.20it/s]
Epoch 6 loss: 0.1484
 70%|███████   | 2730/3900 [22:54<08:49,  2.21it/s]
Epoch 7 loss: 0.1209
 80%|████████  | 3120/3900 [26:10<05:56,  2.19it/s]
Epoch 8 loss: 0.0981
 90%|█████████ | 3510/3900 [29:27<02:57,  2.20it/s]
Epoch 9 loss: 0.0826
100%|██████████| 3900/3900 [32:43<00:00,  2.20it/s]
Epoch 10 loss: 0.0773

Here is a sample result:
outputs = model.generate(input_ids=input_ids)
print(outputs)

tensor([[    2,     0, 10845,  1215,  5489,   102,  5885,  3850,  6433,  7681,
          1215, 14760,  6979,  2617,  1215,  4897,  1455,   507,  1215,  4897,
         13908,   131,    94,   638,   131,  1190,   854,  1906,  1116,     6,
           589,     9,  3161,  7586,  7586,  7586, 14760,  7586,  1640,  7586,
          4397,    22, 12714, 12945,  6034,   113,  7586, 29015,  7586,  7586,
          1437,    22,  7586,   113,    22, 15231,   854, 13083,  3243,  7586,
           113,     2]])

weirdness1 → Typically the target output is no more than 10-15 tokens long, yet when the above is decoded I get a lot of noisy tokens. The weirdest part: this example comes from train_dataset, not even eval_dataset!

weirdness2 → look at the very first token: it is EOS (id 2 corresponds to EOS).
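
(For reference, the special token ids can be checked directly; a quick sketch using the same checkpoint as below. Note that generate() starts decoding from model.config.decoder_start_token_id, which BART sets to the EOS id, so an EOS in position 0 is by itself expected.)

from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# For BART, id 0 is <s> (BOS) and id 2 is </s> (EOS).
print(tokenizer.bos_token_id, tokenizer.eos_token_id)  # 0 2
# generate() prepends this id to every decoded sequence; for BART it equals EOS.
print(model.config.decoder_start_token_id)  # 2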

What am I missing?
Doing labels[labels == tokenizer.pad_token_id] = -100 prevents the loss from being computed on the pad tokens that follow EOS during training, so there is no penalising factor for tokens predicted after EOS, right?
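
A minimal toy sketch of that masking behaviour (random logits, not my actual model):

import torch
import torch.nn.functional as F

# Toy example: 4 target positions, vocabulary of size 5.
logits = torch.randn(4, 5)
labels = torch.tensor([3, 1, -100, -100])  # last two positions masked out

# cross_entropy ignores positions labelled -100 (its default ignore_index),
# so nothing the model predicts at those positions is penalised.
loss = F.cross_entropy(logits, labels, ignore_index=-100)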

Here are the relevant parts of my setup:

from transformers import AutoTokenizer, BartForConditionalGeneration
import torch


model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

max_source_length = 90
max_target_length = 90

def tokenization_function(batch):
    # Tokenize sources and targets, padding both to a fixed length.
    model_inputs = tokenizer(batch['user_request'], padding="max_length", max_length=max_source_length, truncation=True, return_tensors="pt")
    labels = tokenizer(batch['command'], padding="max_length", max_length=max_target_length, truncation=True, return_tensors="pt")
    model_inputs["decoder_attention_mask"] = labels['attention_mask']
    labels = labels["input_ids"]
    # Replace pad ids with -100 so the loss ignores them.
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

tokenized_dataset = dataset.map(tokenization_function, batched=True, batch_size=1024)
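
(As an aside, fixed max_length padding is not required. A common alternative, sketched here rather than what I actually ran, is to tokenize without padding and let DataCollatorForSeq2Seq pad each batch dynamically, writing -100 into the label padding for you.)

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# label_pad_token_id=-100 is the default; shown explicitly for clarity.
# Assumes the tokenization step above was run without padding.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)
train_dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True, collate_fn=data_collator)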
...
...
...
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
train_loss = [0] * num_epochs
model.train()
for epoch in range(num_epochs):

    for step, batch in enumerate(train_dataloader):  # 'step' avoids shadowing the builtin iter()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Passing labels makes the model compute the cross-entropy loss internally.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        # Accumulate the sum of per-example losses for this epoch.
        train_loss[epoch] += loss.item() * input_ids.shape[0]

    # Average over the full training set.
    train_loss[epoch] = train_loss[epoch] / len(train_dataloader.dataset)

    print(f'Epoch {epoch+1} loss: {train_loss[epoch]:.4f}')

cc: @nielsr
(Apologies for bothering you)

The issue was in my inference script, i.e. model.generate(input_ids, num_beams=2, max_length=max_target_length). By default min_length is not 0; I believe it is around 50 for this checkpoint. After changing it to
model.generate(input_ids, num_beams=2, min_length=0, max_length=max_target_length) the results are as expected.

I am not sure exactly why this happens, but setting min_length=0 fixes the issue. (Presumably generate() suppresses EOS until min_length tokens have been produced, which forces the noisy continuation seen above.)
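
For completeness, a sketch of the fixed call together with decoding; batch_decode with skip_special_tokens=True is what strips the EOS/pad tokens from the text:

model.eval()
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        num_beams=2,
        min_length=0,  # override the checkpoint's summarization default
        max_length=max_target_length,
    )
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))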
