Can we convert a dynamic DNN model to TorchScript?

Hi all.

I’m trying to convert a SwitchTransformer model to TorchScript.
(SwitchTransformer is an MoE DNN based on Google's T5 model.)

When I convert both T5 and SwitchTransformer, T5 traces without error, but SwitchTransformer fails with the following error.

/root/HuggingFace/.HF/lib/python3.8/site-packages/transformers/modeling_utils.py:776: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if causal_mask.shape[1] < attention_mask.shape[1]:
Traceback (most recent call last):
  File "example.py", line 423, in <module>
    traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
  File "/root/HuggingFace/.HF/lib/python3.8/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/root/HuggingFace/.HF/lib/python3.8/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions

I suspect this is caused by the dynamic nature of SwitchTransformer (it is an MoE model), but I'm not sure.
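
For context on the TracerWarning in the log above, here is a minimal sketch of how torch.jit.trace handles data-dependent Python control flow. The `Toy` module is a hypothetical stand-in, not code from SwitchTransformer; it only illustrates that the branch taken for the example input is frozen into the trace.

```python
import warnings

import torch


class Toy(torch.nn.Module):
    """Hypothetical module with a data-dependent Python branch."""

    def forward(self, x):
        # The tensor comparison is converted to a Python bool here,
        # which is what triggers the TracerWarning.
        if x.sum() > 0:
            return x + 10
        return x - 10


toy = Toy()
pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced = torch.jit.trace(toy, pos)  # records only the `x + 10` branch

eager_out = toy(neg)      # eager mode re-evaluates the branch: x - 10
traced_out = traced(neg)  # the trace replays the recorded branch: x + 10
```

The traced module silently gives a different result from eager mode on the second input, which is why tracing is risky for models whose control flow depends on the input values.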

Here is the code for T5, which traces successfully:

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small', torchscript=True)
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
attention_mask = input_ids.ne(model.config.pad_token_id).long()
decoder_input_ids = tokenizer('<pad> <extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids

traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
torch.jit.save(traced_model, "traced_t5.pt")

And here is the code for SwitchTransformer, which fails:

from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
import torch

# Tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    "google/switch-base-8", resume_download=True)
model = SwitchTransformersForConditionalGeneration.from_pretrained(
    "google/switch-base-8",
    resume_download=True, torch_dtype=torch.bfloat16,
    torchscript=True,
)

input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
output_text = "<pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
decoder_input_ids = tokenizer(output_text, return_tensors="pt", padding=True).input_ids

attention_mask = input_ids.ne(model.config.pad_token_id).long()

# model.eval()

traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
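
One possible workaround for the RuntimeError above, sketched under the assumption that the model's output tuple contains non-tensor entries such as None (this is not confirmed for SwitchTransformer here): wrap the model so that only tensors are returned, then trace the wrapper. `ToyModel` and `TensorOnlyWrapper` are hypothetical names used for illustration only.

```python
import torch


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in whose output tuple contains a None entry."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.proj(x), None


class TensorOnlyWrapper(torch.nn.Module):
    """Hypothetical wrapper that drops every non-tensor output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return tuple(o for o in outputs if isinstance(o, torch.Tensor))


toy = ToyModel().eval()
x = torch.randn(3, 4)

# Tracing the raw model is expected to fail because of the None output.
try:
    torch.jit.trace(toy, x)
    trace_error = None
except RuntimeError as err:
    trace_error = str(err)

# Tracing the wrapper only sees tensors in the output tuple.
wrapped = torch.jit.trace(TensorOnlyWrapper(toy), x)
(logits,) = wrapped(x)
```

For the real model, running it once in eager mode and printing the type of each output element would show whether this assumption holds before writing such a wrapper.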