Seq2Seq T5-base with Seq2SeqTrainer RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I am replicating the Hugging Face summarization example (notebooks/summarization.ipynb in the huggingface/notebooks repository on GitHub)

on my own system with 2 GPUs, using my own data, which I load as a Hugging Face Datasets dataset:

from datasets import Dataset, DatasetDict, load_metric
from transformers import (T5Tokenizer, DataCollatorForSeq2Seq,
                          Seq2SeqTrainingArguments, Seq2SeqTrainer)
import numpy as np
import nltk

dataset = Dataset.from_pandas(df)
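# (Illustrative only, not my real data: `df` is a pandas DataFrame with
#  plain-string "text" and "label" columns, roughly like
#      import pandas as pd
#      df = pd.DataFrame({
#          "text": ["first source document ...", "second source document ..."],
#          "label": ["first target summary ...", "second target summary ..."],
#      })
#  where "text" is the input document and "label" is the target summary.)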

model_name = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_name)

max_input_length = 256
max_target_length = 128

def preprocess_function(examples):
    model_inputs = tokenizer(examples["text"], max_length=max_input_length, padding=True, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["label"], max_length=max_target_length, padding=True, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(preprocess_function, batched=True)

tokenized_datasets = tokenized_datasets.remove_columns(["text", "label"])

train_test = tokenized_datasets.train_test_split(test_size=0.2)
tokenized_datasets_split = DatasetDict({
    'train': train_test['train'],
    'test': train_test['test']})

train_dataset = tokenized_datasets_split["train"].shuffle(seed=42)
test_dataset = tokenized_datasets_split["test"].shuffle(seed=42)
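For completeness, the `model` object passed to the data collator and trainer below is just the pretrained t5-base checkpoint, loaded roughly like this (assuming the standard conditional-generation class for T5):

from transformers import T5ForConditionalGeneration

# The seq2seq T5 model fine-tuned below; loaded from the same checkpoint name
# as the tokenizer.
model = T5ForConditionalGeneration.from_pretrained(model_name)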

I am trying to fine-tune t5-base with this data. Training runs fine, but as soon as evaluation kicks in at the end of the first epoch I get the error below:

args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

metric = load_metric("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

The full traceback:

     46     return {k: round(v, 4) for k, v in result.items()}
     48 trainer = Seq2SeqTrainer(
     49     model=model,
     50     args=args,
   (...)
     55     compute_metrics=compute_metrics,
     56 )
---> 58 trainer.train()

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:1391, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1388         break
   1390 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 1391 self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1393 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   1394     if is_torch_tpu_available():
   1395         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:1491, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1489 metrics = None
   1490 if self.control.should_evaluate:
-> 1491     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   1492     self._report_to_hp_search(trial, epoch, metrics)
   1494 if self.control.should_save:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:75, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, max_length, num_beams)
     73 self._max_length = max_length if max_length is not None else self.args.generation_max_length
     74 self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
---> 75 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:2113, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   2110 start_time = time.time()
   2112 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 2113 output = eval_loop(
   2114     eval_dataloader,
   2115     description="Evaluation",
   2116     # No point gathering the predictions if there are no metrics, otherwise we defer to
   2117     # self.args.prediction_loss_only
   2118     prediction_loss_only=True if self.compute_metrics is None else None,
   2119     ignore_keys=ignore_keys,
   2120     metric_key_prefix=metric_key_prefix,
   2121 )
   2123 total_batch_size = self.args.eval_batch_size * self.args.world_size
   2124 output.metrics.update(
   2125     speed_metrics(
   2126         metric_key_prefix,
   (...)
   2130     )
   2131 )

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:2285, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   2282         batch_size = observed_batch_size
   2284 # Prediction step
-> 2285 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   2287 # Update containers on host
   2288 if loss is not None:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:167, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
    160 # XXX: adapt synced_gpus for fairscale as well
    161 gen_kwargs = {
    162     "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
    163     "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
    164     "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
    165 }
--> 167 generated_tokens = self.model.generate(
    168     inputs["input_ids"],
    169     attention_mask=inputs["attention_mask"],
    170     **gen_kwargs,
    171 )
    172 # in case the batch is shorter than max length, the output should be padded
    173 if generated_tokens.shape[-1] < gen_kwargs["max_length"]:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/generation_utils.py:922, in GenerationMixin.generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
    918 encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
    920 if self.config.is_encoder_decoder:
    921     # add encoder_outputs to model_kwargs
--> 922     model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
    924     # set input_ids as decoder_input_ids
    925     if "decoder_input_ids" in model_kwargs:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/generation_utils.py:417, in GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs)
    411     encoder = self.get_encoder()
    412     encoder_kwargs = {
    413         argument: value
    414         for argument, value in model_kwargs.items()
    415         if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
    416     }
--> 417     model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
    418 return model_kwargs

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py:904, in T5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    902 if inputs_embeds is None:
    903     assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
--> 904     inputs_embeds = self.embed_tokens(input_ids)
    906 batch_size, seq_length = input_shape
    908 # required mask seq length can be calculated via length of past

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/sparse.py:158, in Embedding.forward(self, input)
    157 def forward(self, input: Tensor) -> Tensor:
--> 158     return F.embedding(
    159         input, self.weight, self.padding_idx, self.max_norm,
    160         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/functional.py:2183, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2177     # Note [embedding_renorm set_grad_enabled]
   2178     # XXX: equivalent to
   2179     # with torch.no_grad():
   2180     #   torch.embedding_renorm_
   2181     # remove once script supports set_grad_enabled
   2182     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2183 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

I have confirmed that training itself happens on the GPU, so I am not sure what is still on the CPU for this error to appear during evaluation. Any guidance would be appreciated.
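For reference, this is a minimal sketch of the kind of sanity check I mean, using the objects defined above (`get_eval_dataloader` just yields one collated eval batch):

import torch

print(torch.cuda.device_count())          # number of visible GPUs (2 here)
print(trainer.args.device)                # device the Trainer targets, e.g. cuda:0
print(next(model.parameters()).device)    # where the model weights live

# One collated eval batch straight off the DataLoader; these tensors start on
# the CPU, and the Trainer is expected to move them to the GPU before the
# forward/generate call.
batch = next(iter(trainer.get_eval_dataloader()))
print({k: v.device for k, v in batch.items()})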