Fine-tuning CLIP: cannot evaluate at steps and calculate validation loss

I want to fine-tune CLIP, and it works when evaluation_strategy is “epoch”.
But I want to see the loss on the evaluation dataset during training, so I set evaluation_strategy to “steps”. When I do this, I get the following error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-18-507c166e7f20> in <module>
      1 from transformers.trainer_utils import get_last_checkpoint
      2 # train_result = trainer.train(resume_from_checkpoint=get_last_checkpoint("./checkpoints"))
----> 3 train_result = trainer.train()

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1496             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1497         )
-> 1498         return inner_training_loop(
   1499             args=args,
   1500             resume_from_checkpoint=resume_from_checkpoint,

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1815                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
   1816 
-> 1817                     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1818                 else:
   1819                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2036         metrics = None
   2037         if self.control.should_evaluate:
-> 2038             metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2039             self._report_to_hp_search(trial, self.state.global_step, metrics)
   2040 

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   2756 
   2757         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 2758         output = eval_loop(
   2759             eval_dataloader,
   2760             description="Evaluation",

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   2934 
   2935             # Prediction step
-> 2936             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   2937             inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
   2938 

/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   3197             return (loss, None, None)
   3198 
-> 3199         logits = nested_detach(logits)
   3200         if len(logits) == 1:
   3201             logits = logits[0]

/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
    157     "Detach `tensors` (even if it's a nested list/tuple of tensors)."
    158     if isinstance(tensors, (list, tuple)):
--> 159         return type(tensors)(nested_detach(t) for t in tensors)
    160     return tensors.detach()
    161 

/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py in <genexpr>(.0)
    157     "Detach `tensors` (even if it's a nested list/tuple of tensors)."
    158     if isinstance(tensors, (list, tuple)):
--> 159         return type(tensors)(nested_detach(t) for t in tensors)
    160     return tensors.detach()
    161 

/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
    158     if isinstance(tensors, (list, tuple)):
    159         return type(tensors)(nested_detach(t) for t in tensors)
--> 160     return tensors.detach()
    161 
    162 

AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach'

To overcome this, I should be able to pass a list of keys to ignore during the evaluation_loop.
After training, you can do this with the following statement:

trainer.evaluate(ignore_keys=["text_model_output", "vision_model_output", "text_embeds", "logits_per_image"])
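For anyone wondering why ignore_keys helps: per the traceback, Trainer.prediction_step collects the model's outputs into a tuple and calls nested_detach on it. CLIP's output also contains text_model_output and vision_model_output, which are BaseModelOutputWithPooling objects rather than tensors, so calling .detach() on them fails. Here is a simplified, pure-Python re-creation of that path. FakeTensor, FakeBaseModelOutputWithPooling and collect_logits are hypothetical stand-ins, not transformers classes, and the key filtering is an assumed simplification of prediction_step; nested_detach is copied from the traceback.

```python
# Simplified sketch (NOT the real transformers source) of how the Trainer
# turns a model output into "logits" and detaches them during evaluation.


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only provides .detach()."""

    def __init__(self, name):
        self.name = name

    def detach(self):
        return self


class FakeBaseModelOutputWithPooling:
    """Hypothetical stand-in for BaseModelOutputWithPooling: no .detach()."""


def nested_detach(tensors):
    # Same logic as transformers.trainer_pt_utils.nested_detach in the traceback.
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    return tensors.detach()


def collect_logits(outputs, ignore_keys):
    # Assumed simplification of prediction_step: keep every output entry
    # except the loss and the ignored keys, then detach what is left.
    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
    return nested_detach(logits)


# The fields a CLIP model output carries when the loss is returned.
clip_outputs = {
    "loss": FakeTensor("loss"),
    "logits_per_image": FakeTensor("logits_per_image"),
    "logits_per_text": FakeTensor("logits_per_text"),
    "text_embeds": FakeTensor("text_embeds"),
    "image_embeds": FakeTensor("image_embeds"),
    "text_model_output": FakeBaseModelOutputWithPooling(),
    "vision_model_output": FakeBaseModelOutputWithPooling(),
}

try:
    collect_logits(clip_outputs, ignore_keys=[])
except AttributeError as err:
    # Mirrors the AttributeError in the traceback above.
    print("Without ignore_keys:", err)

kept = collect_logits(
    clip_outputs,
    ignore_keys=["text_model_output", "vision_model_output", "text_embeds", "logits_per_image"],
)
print("With ignore_keys:", [t.name for t in kept])
```

In this sketch, ignoring only text_model_output and vision_model_output is already enough to avoid the error; the other two keys in the list above are plain tensors.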

But how can I do this while training?

This is my Trainer object:


from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./checkpoints",
                           weight_decay=0.1,
                           dataloader_num_workers=0,
                           per_device_eval_batch_size=8,
                           per_device_train_batch_size=8,
                           num_train_epochs=4,
                           evaluation_strategy="steps",
                           eval_steps=8,
                           warmup_steps=0,
                           learning_rate=5e-05,
                           report_to="wandb",
                           ),
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=processor
)

@sgugger, I based this partly on your example. It really helped me get started with fine-tuning CLIP, thanks for that!!
If you have any idea how I can get the validation loss during training, that would be fantastic!

Hello Vincent

Any update on this error? :smile:
I have a similar issue trying to replicate the ViTForImageClassification class.

Best,
Cristóbal

Nope, unfortunately not :frowning:

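As an update, if anyone runs into this later: you can pass the list through the ignore_keys_for_eval argument when calling Trainer.train(), e.g. trainer.train(ignore_keys_for_eval=["text_model_output", "vision_model_output", "text_embeds", "logits_per_image"]). The Trainer forwards it to self.evaluate(ignore_keys=ignore_keys_for_eval) at every evaluation step (see _maybe_log_save_evaluate in the traceback above). It is part of the method signature in the source code: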

def train(
    self,
    resume_from_checkpoint: Optional[Union[str, bool]] = None,
    trial: Union["optuna.Trial", Dict[str, Any]] = None,
    ignore_keys_for_eval: Optional[List[str]] = None,
    **kwargs,
):