Dealing with multiple sequences in T5ForConditionalGeneration

I've noticed that the outputs for individual sequences (batch_size=1) differ from the outputs for the same sequences processed together as a batch (batch_size=2) with T5ForConditionalGeneration.
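Stripped of the torchtext pipeline below, this is the comparison I mean. A minimal sketch (the slicing and the tolerance are my own assumptions about what counts as "the same output"):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("t5-small")
mdl = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

texts = ["summarize: i am very happy. i am very happy",
         "summarize: i am very safe. i am very safe"]
summaries = ["i am very happy", "i am very safe"]

# batch of 2, padded to the longest sequence, with proper attention masks
enc = tok(texts, return_tensors="pt", padding=True)
lab = tok(summaries, return_tensors="pt", padding=True).input_ids
logits2 = mdl(input_ids=enc.input_ids, attention_mask=enc.attention_mask,
              labels=lab).logits

# batch of 1, first example only
enc1 = tok(texts[0], return_tensors="pt")
lab1 = tok(summaries[0], return_tensors="pt").input_ids
logits1 = mdl(input_ids=enc1.input_ids, attention_mask=enc1.attention_mask,
              labels=lab1).logits

# check whether the first sequence's logits agree across batch sizes
print(torch.allclose(logits1[0], logits2[0, :logits1.size(1)], atol=1e-5))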

Here is the full example to reproduce the result with batch_size=1 vs. batch_size=2:

import torch
import torch.nn.functional as F
import pandas as pd
from torchtext.legacy.data import Field, BucketIterator, TabularDataset

# prepare data
data = {"text": ["summarize: i am very happy. i am very happy",
        "summarize: i am very safe. i am very safe"],
        "summary": ["i am very happy", "i am very safe"]}
df = pd.DataFrame(data)
df.to_csv("debug.csv", index=False)

# set tokenizer of T5-small
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("t5-small")
pad_index = tokenizer.pad_token_id
unk_index = tokenizer.unk_token_id
eos_index = tokenizer.eos_token_id
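# (for t5-small these resolve to pad=0, eos=1, unk=2)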

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

model.resize_token_embeddings(len(tokenizer))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

SRC = Field(tokenize = tokenizer.encode, 
            use_vocab=False,
            lower = False,
            init_token = None, 
            eos_token = eos_index, 
            pad_token=pad_index,
            unk_token=unk_index,
            include_lengths = True)

TRG = Field(tokenize = tokenizer.encode, 
            use_vocab=False,
            init_token = None, 
            eos_token = eos_index, 
            pad_token=pad_index,
            unk_token=unk_index,
            include_lengths = True,
            lower = False)
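# note: the legacy Field defaults to batch_first=False, so batch.src and
# batch.trg come back as [seq_len, batch] tensors; include_lengths=True
# makes each a (tensor, lengths) tuple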


fields = {"text": ("src", SRC), "summary": ("trg", TRG)}
train_data, valid_data, test_data = TabularDataset.splits(
    path="./",
    train="debug.csv",
    validation="debug.csv",
    test="debug.csv",
    format='csv',
    fields=fields)

BATCH_SIZE = 2

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     sort_within_batch = True,
     sort_key = lambda x : len(x.src),
     device = device)
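# BucketIterator groups examples of similar length to cut down padding;
# sort_within_batch=True orders the examples inside each batch by the sort_key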

for i, batch in enumerate(train_iterator):
    src, src_len = batch.src  # [seq_len, batch]
    trg, trg_len = batch.trg  # [seq_len, batch]
    print(trg_len)
    #attention_mask = torch.ones((2, 15)).to(device)
    #decoder_attention_mask = torch.ones((2, 7)).to(device)

    # transpose to [batch, seq_len] for the model; a .view() here would
    # reinterpret memory and scramble tokens across sequences
    logits = model(input_ids=src.transpose(0, 1),
                   labels=trg.transpose(0, 1),
                   attention_mask=None,
                   decoder_attention_mask=None
                   ).logits  # [batch, seq_len, vocab]
    X = F.softmax(logits, dim=-1)
    ids = X.argmax(dim=-1)  # [batch, seq_len]
    y = tokenizer.batch_decode(sequences=ids, skip_special_tokens=True)
    z0 = tokenizer.batch_decode(sequences=trg.transpose(0, 1), skip_special_tokens=True)
    print(" ".join(y))
    print("*********")
    print(" ".join(z0))
    print("*********")