Here I note that the outputs for individual sequences differ from the outputs for the same sequences run in a batch with T5ForConditionalGeneration.
Here is an example to reproduce the issue (batch_size=1 vs. batch_size=2):
import torch
import torch.nn.functional as F
import pandas as pd
from torchtext.legacy.data import Field, BucketIterator, TabularDataset
# prepare data
data = {"text": ["summarize: i am very happy. i am very happy",
                 "summarize: i am very safe. i am very safe"],
        "summary": ["i am very happy", "i am very safe"]}
df = pd.DataFrame(data)
df.to_csv("debug.csv", index=False)
# set tokenizer of T5-small
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("t5-small")
pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
unk_index = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)
eos_index = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.resize_token_embeddings(len(tokenizer))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
SRC = Field(tokenize=tokenizer.encode,
            use_vocab=False,
            lower=False,
            init_token=None,
            eos_token=eos_index,
            pad_token=pad_index,
            unk_token=unk_index,
            include_lengths=True)
TRG = Field(tokenize=tokenizer.encode,
            use_vocab=False,
            lower=False,
            init_token=None,
            eos_token=eos_index,
            pad_token=pad_index,
            unk_token=unk_index,
            include_lengths=True)
fields = {"text": ("src", SRC), "summary": ("trg", TRG)}
train_data, valid_data, test_data = TabularDataset.splits(
    path="./",
    train="debug.csv",
    validation="debug.csv",
    test="debug.csv",
    format="csv",
    fields=fields)
BATCH_SIZE = 2
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device)
for i, batch in enumerate(train_iterator):
    src, src_len = batch.src   # torchtext returns (seq_len, batch)
    trg, trg_len = batch.trg
    print(trg_len)
    # attention_mask = torch.ones((2, 15)).to(device)
    # decoder_attention_mask = torch.ones((2, 7)).to(device)
    logits = model(input_ids=src.t(),   # transpose to (batch, seq_len); .view() would scramble the tokens
                   labels=trg.t(),
                   attention_mask=None,
                   decoder_attention_mask=None).logits
    probs = F.softmax(logits, dim=-1)   # (batch, seq_len, vocab)
    ids = probs.argmax(dim=-1)
    y = tokenizer.batch_decode(sequences=ids, skip_special_tokens=True)
    z0 = tokenizer.batch_decode(sequences=trg.t(), skip_special_tokens=True)
    print(" ".join(y))
    print("*********")
    print(" ".join(z0))
    print("*********")