Hi everyone,
I want to fine-tune T5 for text classification on the dair-ai/emotion dataset, but I'm running into an error during training. Here is my code:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset
# Load the dataset
dataset = load_dataset("dair-ai/emotion")
# Subset the dataset to only 100 samples
train_dataset = dataset["train"].shuffle(seed=42).select(range(100))
validation_dataset = dataset["validation"].shuffle(seed=42).select(range(20)) # Use 20 samples for validation
# Load the tokenizer and model
model_name = "google-t5/t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Mapping numerical labels to text
label_map = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}
# Tokenize the dataset
def preprocess_function(examples):
    inputs = [f"emotion: {text}" for text in examples["text"]]
    targets = [label_map[label] for label in examples["label"]]  # Convert label to text
    print(targets)
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=3, truncation=True)
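    # ^ note: unlike the inputs above, the targets are tokenized without
    # padding here, so each label list can end up with a different length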
model_inputs["labels"] = labels["input_ids"]
return model_inputs
# Apply preprocessing to the subsetted datasets
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
tokenized_validation_dataset = validation_dataset.map(preprocess_function, batched=True)
# Debug: Print a sample to check structure
print(tokenized_train_dataset[0])
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,  # Reduce batch size for the smaller dataset
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    save_steps=50,  # Save more frequently for the small dataset
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
)
# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_validation_dataset,
    tokenizer=tokenizer,
)
# Fine-tune the model
trainer.train()
# Save the model
trainer.save_model("./fine-tuned-t5-small-emotion-100-samples")
tokenizer.save_pretrained("./fine-tuned-t5-small-emotion-100-samples")
Here is the error traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-18-1337150a5d0c> in <cell line: 0>()
62
63 # Fine-tune the model
---> 64 trainer.train()
65
66 # Save the model
/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2169 hf_hub_utils.enable_progress_bars()
2170 else:
-> 2171 return inner_training_loop(
2172 args=args,
2173 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2529 )
2530 with context():
-> 2531 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2532
2533 if (
/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in training_step(self, model, inputs, num_items_in_batch)
3673
3674 with self.compute_loss_context_manager():
-> 3675 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3676
3677 del inputs
/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1889
1890 # Decode
-> 1891 decoder_outputs = self.decoder(
1892 input_ids=decoder_input_ids,
1893 attention_mask=decoder_attention_mask,
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1002 inputs_embeds = self.embed_tokens(input_ids)
1003
-> 1004 batch_size, seq_length = input_shape
1005
1006 if use_cache is True:
ValueError: not enough values to unpack (expected 2, got 1)
I suspect the issue is with how I process the labels. The failure happens in the T5 decoder at "batch_size, seq_length = input_shape", so it looks like the decoder is receiving 1-D decoder_input_ids instead of a (batch_size, seq_length) batch, and since decoder_input_ids are derived from the labels, my guess is that tokenizing the targets without padding (so every label list has a different length) is what breaks batching. Could someone explain what exactly is causing this error and how to fix it?
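In case it helps, here is a variant of preprocess_function I was considering, where the targets are padded to a fixed length as well and the pad token ids in the labels are replaced with -100 (my assumption being that the loss ignores -100). I haven't confirmed this is the right fix:

# Variant I was considering (untested as a fix): pad the targets too, and mask
# the padding in the labels with -100, assuming the loss skips those positions
def preprocess_function(examples):
    inputs = [f"emotion: {text}" for text in examples["text"]]
    targets = [label_map[label] for label in examples["label"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    # Pad the labels to max_length so every row has the same shape
    labels = tokenizer(targets, max_length=3, truncation=True, padding="max_length")
    # Replace pad token ids with -100 so they are (I assume) ignored by the loss
    model_inputs["labels"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

I have also seen DataCollatorForSeq2Seq and Seq2SeqTrainer mentioned for this kind of setup, but I'm not sure whether they are needed here.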
Thank you!