Hello. I’ve been struggling with T5ForConditionalGeneration lately. What I want to do is fine-tune a T5 model using PyTorch Lightning. They have a nice tutorial for it using BERT, which creates two classes: one for the dataloader and one for the model. I modified it to suit the dataset I’m using and the T5 model.
Now, the problem: every time I call trainer.fit, it throws ValueError: not enough values to unpack (expected 2, got 1) from modeling_t5.py line 938 (batch_size, seq_length = input_shape). Looking into that file, I noticed that the input_shape variable comes from input_ids.size(), so I’m assuming there is some input_ids in the dataset that ends up with only one dimension. But when I check the dataset’s input_ids with a similar script (sketched below), no error is thrown.
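The kind of shape check I mean looks roughly like this (a minimal sketch, not my exact script; it mirrors the tokenizer settings used in convert_to_features below, and "t5-small" is just a placeholder for the checkpoint):

import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=True)  # placeholder checkpoint
dataset = datasets.load_dataset("rotten_tomatoes", "default")

for split in dataset.keys():
    encodings = tokenizer.batch_encode_plus(
        dataset[split]["text"], max_length=128, padding="max_length", truncation=True
    )
    # every example should be a flat list of exactly max_length token ids,
    # so a batch of them should stack into a 2-D (batch_size, seq_length) tensor
    for ids in encodings["input_ids"]:
        assert len(ids) == 128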
Here’s the dataloader class (imports for both classes included):

from datetime import datetime
from typing import Optional

import datasets
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, LightningModule
from transformers import (
    AutoConfig,
    AutoTokenizer,
    T5ForConditionalGeneration,
    get_linear_schedule_with_warmup,
)
class RottenDataModule(LightningDataModule):

    field_map = ["text"]
    num_classes = 2

    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.text_fields = self.field_map
        self.num_labels = self.num_classes
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset("rotten_tomatoes", "default")
        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)
            #self.dataset[split]["input_ids"] = self.dataset[split]["input_ids"].unsqueeze(0)
            #self.dataset[split]["attention_mask"] = self.dataset[split]["attention_mask"].unsqueeze(0)
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("rotten_tomatoes", "default")
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
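To see what actually reaches the model, one thing I can do is pull a single batch out of this data module and print the tensor shapes (again just a sketch, with a placeholder checkpoint name):

dm = RottenDataModule("t5-small")  # placeholder checkpoint
dm.prepare_data()
dm.setup("fit")

batch = next(iter(dm.train_dataloader()))
for name, tensor in batch.items():
    # each field the dataloader yields, with its batched shape
    print(name, tensor.shape)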
And here’s the model class:
class RottenTransformer(LightningModule):

    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "accuracy", experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        # debug print: check that input_ids arrives as a 2-D (batch_size, seq_length) tensor
        a, b = inputs["input_ids"].size()
        print(a, b)
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
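For completeness, everything is created and fit roughly like this (a sketch: the checkpoint name is a placeholder, and the Trainer arguments other than devices, which appears in the traceback below, are assumptions rather than my exact cell):

from pytorch_lightning import Trainer

dm = RottenDataModule(model_name_or_path="t5-small")  # placeholder checkpoint
model = RottenTransformer(model_name_or_path="t5-small", num_labels=dm.num_classes)

trainer = Trainer(
    max_epochs=1,        # assumed
    accelerator="auto",  # assumed
    devices=1 if torch.cuda.is_available() else None,  # as in the traceback
)
trainer.fit(model, datamodule=dm)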
Here’s the error stack:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-9-ad67c953499b> in <module>
14 devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
15 )
---> 16 trainer.fit(model, datamodule=dm)
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
695 self.strategy.model = model
696 self._call_and_handle_interrupt(
--> 697 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
698 )
699
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
648 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
649 else:
--> 650 return trainer_fn(*args, **kwargs)
651 # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
652 except KeyboardInterrupt as exception:
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
735 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
736 )
--> 737 results = self._run(model, ckpt_path=self.ckpt_path)
738
739 assert self.state.stopped
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1166 self._checkpoint_connector.resume_end()
1167
-> 1168 results = self._run_stage()
1169
1170 log.detail(f"{self.__class__.__name__}: trainer tearing down")
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
1252 if self.predicting:
1253 return self._run_predict()
-> 1254 return self._run_train()
1255
1256 def _pre_training_routine(self):
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1274
1275 with isolate_rng():
-> 1276 self._run_sanity_check()
1277
1278 # enable train mode
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self)
1343 # run eval step
1344 with torch.no_grad():
-> 1345 val_loop.run()
1346
1347 self._call_callback_hooks("on_sanity_check_end")
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
153 if self.num_dataloaders > 1:
154 kwargs["dataloader_idx"] = dataloader_idx
--> 155 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
156
157 # store batch level output per dataloader
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dl_max_batches, kwargs)
141
142 # lightning module methods
--> 143 output = self._evaluation_step(**kwargs)
144 output = self._evaluation_step_end(output)
145
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, **kwargs)
238 """
239 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 240 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
241
242 return output
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
1704
1705 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1706 output = fn(*args, **kwargs)
1707
1708 # restore current_fx when nested context
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py in validation_step(self, *args, **kwargs)
368 with self.precision_plugin.val_step_context():
369 assert isinstance(self.model, ValidationStep)
--> 370 return self.model.validation_step(*args, **kwargs)
371
372 def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
<ipython-input-8-58386b64c860> in validation_step(self, batch, batch_idx, dataloader_idx)
34
35 def validation_step(self, batch, batch_idx, dataloader_idx=0):
---> 36 outputs = self(**batch)
37 val_loss, logits = outputs[:2]
38
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-8-58386b64c860> in forward(self, **inputs)
26 a, b = inputs["input_ids"].size()
27 print(a, b)
---> 28 return self.model(**inputs)
29
30 def training_step(self, batch, batch_idx):
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1649 output_attentions=output_attentions,
1650 output_hidden_states=output_hidden_states,
-> 1651 return_dict=return_dict,
1652 )
1653
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
936 inputs_embeds = self.embed_tokens(input_ids)
937
--> 938 batch_size, seq_length = input_shape
939
940 # required mask seq length can be calculated via length of past
ValueError: not enough values to unpack (expected 2, got 1)
Thank you for your help.