I am looking at the example for TorchScripting BERT-like models here: Exporting 🤗 Transformers Models. I have a basic question about the dummy inputs passed for tracing, which don’t make obvious sense to me.
The input passed is a list containing `token_ids` and `segment_ids` (or `token_type_ids`), which TorchScript will unpack. Now, `BertModel.forward()` expects `input_ids` and `attention_mask` as its first and second arguments, respectively. So why is `segment_ids` being passed as the second argument, both for tracing and later for inference with the loaded TorchScript model? Does it somehow work because of the `torchscript=True` flag that’s passed when instantiating the model? If so, how does it work?
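To make the concern concrete, here is a minimal sketch of the positional binding in plain Python, with no torch involved. It assumes that `torch.jit.trace` passes the elements of the example-input list to `forward()` positionally (roughly `model(*inputs)`). `bert_forward_stub` is a hypothetical stand-in that only copies the order of the first three parameters of `BertModel.forward()`, and the token and segment values are made up for illustration.

```python
import inspect


# Hypothetical stand-in: mirrors only the parameter ORDER of the first three
# arguments of BertModel.forward(); it performs no computation.
def bert_forward_stub(input_ids=None, attention_mask=None, token_type_ids=None):
    return input_ids, attention_mask, token_type_ids


def bind_positional(fn, example_inputs):
    """Return which parameter each element of a positional input list binds to."""
    bound = inspect.signature(fn).bind(*example_inputs)
    return dict(bound.arguments)


token_ids = [101, 2040, 2001, 102]  # made-up token ids
segment_ids = [0, 0, 1, 1]          # made-up segment (token type) ids

# Passing [token_ids, segment_ids] positionally, like the dummy input list.
binding = bind_positional(bert_forward_stub, [token_ids, segment_ids])
print(binding)
# {'input_ids': [101, 2040, 2001, 102], 'attention_mask': [0, 0, 1, 1]}

# In an ordinary Python call, naming the argument binds it where intended.
keyword_binding = dict(
    inspect.signature(bert_forward_stub).bind(token_ids, token_type_ids=segment_ids).arguments
)
print(keyword_binding)
# {'input_ids': [101, 2040, 2001, 102], 'token_type_ids': [0, 0, 1, 1]}
```

Under that assumption, the second tensor lands in `attention_mask`, not `token_type_ids`, which is exactly the part that seems surprising.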