New model output types

As requested in #5226, model outputs are now more informative than plain tuples, without any breaking change: PyTorch models now return the appropriate subclass of ModelOutput. Here is an example with a sequence classification model:

from transformers import BertTokenizer, BertForSequenceClassification
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
outputs = model(**inputs, labels=labels)

Then outputs will be a SequenceClassifierOutput object, which has the returned elements as attributes. The old syntax

loss, logits = outputs[:2]

will still work, but you can also do

loss = outputs.loss
logits = outputs.logits

or also

loss = outputs["loss"]
logits = outputs["logits"]

Under the hood, outputs is a dataclass with optional fields that may be set to None if they are not returned by the model (like attentions in our example). If you index by integer or by slice, the None fields are skipped (for backward compatibility). If you try to access a field that is set to None by its key (for instance here outputs["attentions"]), it will raise an error.
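To make the indexing rules above concrete, here is a minimal sketch in plain Python. ToyClassifierOutput is a hypothetical stand-in written for illustration, not the library's ModelOutput implementation; it only mimics the behavior described in this post (None fields skipped for integer and slice indexing, an error for key access on a None field).

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyClassifierOutput:
    """Hypothetical stand-in for a ModelOutput subclass (illustration only)."""

    loss: Optional[Any] = None
    logits: Optional[Any] = None
    attentions: Optional[Any] = None

    def to_dict(self):
        # Keep only the fields that were actually set (not None).
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        }

    def to_tuple(self):
        # Same fields as to_dict(), in declaration order.
        return tuple(self.to_dict().values())

    def __getitem__(self, key):
        # String keys look up by name; integers and slices index the tuple view.
        if isinstance(key, str):
            return self.to_dict()[key]
        return self.to_tuple()[key]


outputs = ToyClassifierOutput(loss=0.7, logits=[0.1, 0.9])

# Old tuple-style syntax: attentions is None, so it is skipped.
loss, logits = outputs[:2]

# Attribute and key access give the same values.
same_loss = outputs.loss
same_logits = outputs["logits"]

# Without a loss, index 0 is the first non-None field (logits here).
no_loss = ToyClassifierOutput(logits=[0.1, 0.9])
first = no_loss[0]
```

In this sketch, outputs.attentions is simply None, while outputs["attentions"] raises a KeyError because the None field is absent from the dict view.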

You can convert those outputs to a regular tuple/dict with outputs.to_tuple() or outputs.to_dict().

You can revert to the old behavior of getting a tuple by setting return_tuple=True in the config you pass to your model, when you instantiate your model, or when you call your model on some inputs. If you’re using TorchScript (and the config you passed to your model has config.torchscript = True), this will automatically be the case, because jit only handles tuples as outputs.

Hope you like this new feature!


So many quality of life improvements recently.

Thanks for all your work and effort.


Thanks! This is a really nice improvement :slight_smile:


I hadn’t seen this before. This is great! Might help get some confusion out of the way.

FYI, this new feature breaks when using nn.DataParallel (multi-GPU), as torch.nn.parallel.scatter_gather.gather currently can’t gather outputs that are dataclasses.

Workarounds are being developed in the core, but until then, if you run into this issue in your custom code, you may need to add:

+            if isinstance(model, torch.nn.DataParallel):
+                inputs["return_tuple"] = True

before your call:

             outputs = model(**inputs)

For more details see:

P.S. Voting on the PyTorch issue may help get this resolved globally faster.


By voting you mean thumbing-up, right? Done. This would indeed be helpful because I assume that some of the examples may not work with this new functionality out of the box.

(As an aside: nice to see that transformers has grown so much that it is taken into account by the big developer league. :smile:)

Yes and thank you, @BramVanroy.

This would indeed be helpful because I assume that some of the examples may not work with this new functionality out of the box.

Until PyTorch sorts this out, a potential workaround is being worked on:




So after a few weeks of experimenting with this functionality, we have discovered this whole thing is a little bit more complicated than we initially thought.

As mentioned above in this thread, there is a problem with torch.nn.DataParallel. The new outputs also cause problems when exporting to ONNX or other formats. Those can be fixed, but the biggest problem is that TensorFlow does not like those new model output types at all, so we are now left with two different APIs for PyTorch and TensorFlow, which is not what we want for the library.

On the other hand, we have a version that works with TensorFlow and doesn’t break torch.nn.DataParallel in #5981. Its problem is that it breaks backward compatibility: this model output behaves more like a dict than a tuple, so unpacking it (e.g. doing loss, logits = output) iterates over the keys and not the values. This breaks pretty much all tests and examples, since we use that unpacking a lot.
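The unpacking problem can be reproduced with a plain Python mapping. This is a minimal sketch that assumes the dict-like output iterates the way a standard mapping does, as described above; the OrderedDict and the values in it are placeholders, not real model outputs.

```python
from collections import OrderedDict

# Tuple output: unpacking yields the values, as existing code expects.
tuple_output = (0.7, [0.1, 0.9])
loss, logits = tuple_output

# Dict-like output: unpacking iterates over the keys, not the values.
dict_output = OrderedDict(loss=0.7, logits=[0.1, 0.9])
first, second = dict_output

# Accessing by name, or unpacking .values(), avoids the surprise.
loss_by_name = dict_output["loss"]
loss_value, logits_value = dict_output.values()
```

Here first and second end up holding the strings "loss" and "logits", which is why code written as loss, logits = output silently gets the wrong objects instead of failing right away.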

After lots of internal discussion, and since we can’t find a way to have the best of both worlds (i.e., fix the current model output so that it works with TensorFlow and torch.nn.DataParallel without a breaking change), we have decided to partially revert those changes. I’ll start implementing the following tomorrow:

  1. The current ModelOutput will become the more dict-like version in #5981, and it will be used for all TensorFlow models as well as PyTorch ones.
  2. Since this is a breaking change, we’ll flip the switch the other way: instead of getting the new model outputs by default and being able to get back a tuple with return_tuple=True, we will go back to returning tuples by default and add a return_dict argument you can pass to get the new model outputs.
  3. Update all our examples to use return_dict=True and discourage the use of unpacking, so that everyone has time to get used to the upcoming breaking change. Test and fix potential side effects.
  4. In a few months (and probably during a major release), flip the switch and make return_dict=True the default for all models.
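During the transition described in the plan above, custom code may need to cope with both output styles. The sketch below is purely illustrative: toy_forward is a hypothetical stand-in for a model call with a return_dict flag, and get_field is a hypothetical helper, not part of the library.

```python
def toy_forward(return_dict=False):
    """Hypothetical model call: tuple by default, dict-like if return_dict=True."""
    loss, logits = 0.7, [0.1, 0.9]
    if return_dict:
        return {"loss": loss, "logits": logits}
    return (loss, logits)


def get_field(outputs, name, position):
    """Hypothetical helper: read a field from either output style."""
    if isinstance(outputs, tuple):
        # Legacy style: fields are identified by position only.
        return outputs[position]
    # New style: fields are identified by name.
    return outputs[name]


legacy_outputs = toy_forward()
new_outputs = toy_forward(return_dict=True)

legacy_logits = get_field(legacy_outputs, "logits", 1)
new_logits = get_field(new_outputs, "logits", 1)
```

Accessing fields by name rather than unpacking keeps such code working once the default flips.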

@sgugger What is the guideline for using ModelOutput and its subclasses in our own LM classes? ModelOutput seems to be an HF-internal class. TIA