Torchscript Example for BERT

I am looking at the example for torchscripting BERT-like models here: Exporting 🤗 Transformers Models. I have a basic question about the dummy inputs being passed for tracing, which don’t make obvious sense to me.

The input passed is a list containing token_ids and segment_ids (a.k.a. token_type_ids), which torchscript will unpack positionally. Now, BertModel.forward() expects input_ids and attention_mask as the first and second arguments respectively. So why is segment_ids being passed as the second argument, both for tracing and later for inference with the loaded torchscripted model? Does it somehow work because of the torchscript=True flag that’s passed when instantiating the model? If so, how does it work?
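To see why the call runs at all, here is a minimal plain-Python sketch (the `forward` below is a hypothetical stand-in that only mirrors the order of `BertModel.forward`'s first three parameters, not the real implementation): `torch.jit.trace` unpacks the dummy-input list positionally, so whatever tensor sits second in the list is bound to the `attention_mask` parameter, regardless of what the variable is named at the call site.

```python
# Hypothetical stand-in for BertModel.forward: only the parameter ORDER
# (input_ids, attention_mask, token_type_ids) mirrors the real signature.
def forward(input_ids, attention_mask=None, token_type_ids=None):
    # Report which parameter each argument was bound to.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }

token_ids = [101, 2023, 102]  # dummy token ids
segment_ids = [0, 0, 0]       # dummy segment (token_type) ids

# The guide passes [token_ids, segment_ids] for tracing, and the tracer
# forwards the list positionally -- exactly like this star-unpacking:
bound = forward(*[token_ids, segment_ids])

# segment_ids silently lands in the attention_mask slot:
assert bound["attention_mask"] == segment_ids
assert bound["token_type_ids"] is None
```

So nothing about the names matches up; the second tensor simply falls into the second parameter slot.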

cc’ing @lewtun here

@lewtun any insights here?

Hey @hikushalhere, thanks for raising this issue! This looks like an error in the guide, and the only reason the code runs is that the tensor used for segment_ids happens to resemble a valid attention_mask.

The torchscript=True flag is used to ensure the model returns plain tuples instead of ModelOutput objects, which cause JIT tracing errors.
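To make that concrete, here is a toy sketch (not the real transformers implementation; `FakeModelOutput` and `model_forward` are made-up names) of the behavior the flag switches on: with `torchscript=True` the model hands the tracer a plain tuple, whereas the default return type is a dict-like output object.

```python
# Toy stand-in for transformers' ModelOutput: a dict with attribute access.
class FakeModelOutput(dict):
    def __getattr__(self, name):
        return self[name]

def model_forward(hidden, torchscript=False):
    if torchscript:
        # With torchscript=True the model returns a plain tuple,
        # which torch.jit.trace can record without errors.
        return (hidden,)
    # Default: a dict-like ModelOutput, which the tracer chokes on.
    return FakeModelOutput(last_hidden_state=hidden)

assert model_forward([1.0], torchscript=True) == ([1.0],)
assert model_forward([1.0]).last_hidden_state == [1.0]
```

The traced graph therefore exposes outputs only by position (e.g. `outputs[0]`), not by attribute name.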

Would you like to open a PR to fix the guide?