Not enough values to unpack (expected 2, got 1) when training with T5ForConditionalGeneration

Hello. I’ve been struggling with T5ForConditionalGeneration lately. What I want to do is fine-tune a T5 model using PyTorch Lightning. They have a nice tutorial for doing this with BERT, which defines two classes: one for the dataloader and one for the model. I modified it to work with a different dataset and a T5 model.

Now, the problem: every time I call trainer.fit it throws ValueError: not enough values to unpack (expected 2, got 1) from modeling_t5.py line 938 (batch_size, seq_length = input_shape). Looking into that file, I noticed that input_shape comes from input_ids.size(), so I assumed some input_ids coming out of the dataset isn’t 2-D. But when I check the dataset’s input_ids with a similar script, no error is thrown.
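
For reference, this is roughly the kind of standalone check I mean, tokenizing the same rotten_tomatoes split outside Lightning (t5-small is just a stand-in for the checkpoint I actually use):

    import datasets
    from transformers import AutoTokenizer

    # Standalone sanity check outside of Lightning; "t5-small" is only a placeholder checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=True)
    raw = datasets.load_dataset("rotten_tomatoes", "default", split="validation")

    features = tokenizer(
        raw["text"][:32],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # input_ids comes back 2-D here, so the same unpack the forward does works fine.
    batch_size, seq_length = features["input_ids"].size()
    print(batch_size, seq_length)  # e.g. 32 128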

Here’s the dataloader class:

    import datasets
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer


    class RottenDataModule(LightningDataModule):
        field_map = ["text"]
        num_classes = 2
        loader_columns = [
            "datasets_idx",
            "input_ids",
            "token_type_ids",
            "attention_mask",
            "start_positions",
            "end_positions",
            "labels",
        ]

        def __init__(
            self,
            model_name_or_path: str,
            max_seq_length: int = 128,
            train_batch_size: int = 32,
            eval_batch_size: int = 32,
            **kwargs,
        ):
            super().__init__()
            self.model_name_or_path = model_name_or_path
            self.max_seq_length = max_seq_length
            self.train_batch_size = train_batch_size
            self.eval_batch_size = eval_batch_size

            self.text_fields = self.field_map
            self.num_labels = self.num_classes
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

        def setup(self, stage: str):
            self.dataset = datasets.load_dataset("rotten_tomatoes", "default")

            for split in self.dataset.keys():
                self.dataset[split] = self.dataset[split].map(
                    self.convert_to_features,
                    batched=True,
                    remove_columns=["label"],
                )
                self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
                self.dataset[split].set_format(type="torch", columns=self.columns)

                # self.dataset[split]["input_ids"] = self.dataset[split]["input_ids"].unsqueeze(0)
                # self.dataset[split]["attention_mask"] = self.dataset[split]["attention_mask"].unsqueeze(0)

            self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

        def prepare_data(self):
            datasets.load_dataset("rotten_tomatoes", "default")
            AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

        def train_dataloader(self):
            return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

        def val_dataloader(self):
            if len(self.eval_splits) == 1:
                return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
            elif len(self.eval_splits) > 1:
                return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

        def test_dataloader(self):
            if len(self.eval_splits) == 1:
                return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
            elif len(self.eval_splits) > 1:
                return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

        def convert_to_features(self, example_batch, indices=None):
            # Either encode single sentence or sentence pairs
            if len(self.text_fields) > 1:
                texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
            else:
                texts_or_text_pairs = example_batch[self.text_fields[0]]

            # Tokenize the text/text pairs
            features = self.tokenizer.batch_encode_plus(
                texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True
            )

            # Rename label to labels to make it easier to pass to model forward
            features["labels"] = example_batch["label"]

            return features
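
In case it helps, this is roughly how I pull one batch out of the datamodule to inspect it (t5-small again stands in for the real checkpoint):

    # Quick inspection of one training batch; "t5-small" is only a placeholder checkpoint.
    dm = RottenDataModule("t5-small")
    dm.prepare_data()
    dm.setup("fit")

    batch = next(iter(dm.train_dataloader()))
    print({k: tuple(v.shape) for k, v in batch.items()})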

And here’s the model class:

from datetime import datetime
from typing import Optional

import datasets
import torch
from pytorch_lightning import LightningModule
from transformers import (
    AdamW,
    AutoConfig,
    T5ForConditionalGeneration,
    get_linear_schedule_with_warmup,
)


class RottenTransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "accuracy", experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        a, b = inputs["input_ids"].size()
        print(a, b)
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
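
For completeness, the cell that triggers the error looks more or less like this (t5-small stands in for the checkpoint I actually use):

    import torch
    from pytorch_lightning import Trainer, seed_everything

    seed_everything(42)

    # "t5-small" is only a placeholder for the checkpoint I actually fine-tune.
    dm = RottenDataModule(model_name_or_path="t5-small")
    dm.setup("fit")

    model = RottenTransformer(
        model_name_or_path="t5-small",
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
    )

    trainer = Trainer(
        max_epochs=1,
        accelerator="auto",
        devices=1 if torch.cuda.is_available() else None,  # limit to 1 device for iPython runs
    )
    trainer.fit(model, datamodule=dm)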

Here’s the full error traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-9-ad67c953499b> in <module>
    14     devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
    15 )
---> 16 trainer.fit(model, datamodule=dm)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    695         self.strategy.model = model
    696         self._call_and_handle_interrupt(
--> 697             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    698         )
    699 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    648                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    649             else:
--> 650                 return trainer_fn(*args, **kwargs)
    651         # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
    652         except KeyboardInterrupt as exception:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    735             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    736         )
--> 737         results = self._run(model, ckpt_path=self.ckpt_path)
    738 
    739         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1166         self._checkpoint_connector.resume_end()
1167 
-> 1168         results = self._run_stage()
1169 
1170         log.detail(f"{self.__class__.__name__}: trainer tearing down")

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
1252         if self.predicting:
1253             return self._run_predict()
-> 1254         return self._run_train()
1255 
1256     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1274 
1275         with isolate_rng():
-> 1276             self._run_sanity_check()
1277 
1278         # enable train mode

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self)
1343             # run eval step
1344             with torch.no_grad():
-> 1345                 val_loop.run()
1346 
1347             self._call_callback_hooks("on_sanity_check_end")

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    198             try:
    199                 self.on_advance_start(*args, **kwargs)
--> 200                 self.advance(*args, **kwargs)
    201                 self.on_advance_end()
    202                 self._restarting = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    153         if self.num_dataloaders > 1:
    154             kwargs["dataloader_idx"] = dataloader_idx
--> 155         dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    156 
    157         # store batch level output per dataloader

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    198             try:
    199                 self.on_advance_start(*args, **kwargs)
--> 200                 self.advance(*args, **kwargs)
    201                 self.on_advance_end()
    202                 self._restarting = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dl_max_batches, kwargs)
    141 
    142         # lightning module methods
--> 143         output = self._evaluation_step(**kwargs)
    144         output = self._evaluation_step_end(output)
    145 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, **kwargs)
    238         """
    239         hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 240         output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
    241 
    242         return output

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
1704 
1705         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1706             output = fn(*args, **kwargs)
1707 
1708         # restore current_fx when nested context

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py in validation_step(self, *args, **kwargs)
    368         with self.precision_plugin.val_step_context():
    369             assert isinstance(self.model, ValidationStep)
--> 370             return self.model.validation_step(*args, **kwargs)
    371 
    372     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

<ipython-input-8-58386b64c860> in validation_step(self, batch, batch_idx, dataloader_idx)
    34 
    35   def validation_step(self, batch, batch_idx, dataloader_idx=0):
---> 36     outputs = self(**batch)
    37     val_loss, logits = outputs[:2]
    38 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
1131         # Do not call functions when jit is used
1132         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-8-58386b64c860> in forward(self, **inputs)
    26     a, b = inputs["input_ids"].size()
    27     print(a, b)
---> 28     return self.model(**inputs)
    29 
    30   def training_step(self, batch, batch_idx):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
1131         # Do not call functions when jit is used
1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1649             output_attentions=output_attentions,
1650             output_hidden_states=output_hidden_states,
-> 1651             return_dict=return_dict,
1652         )
1653 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
1131         # Do not call functions when jit is used
1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    936             inputs_embeds = self.embed_tokens(input_ids)
    937 
--> 938         batch_size, seq_length = input_shape
    939 
    940         # required mask seq length can be calculated via length of past

ValueError: not enough values to unpack (expected 2, got 1)

Thank you for your help.