Hi all.
I’m trying to convert a SwitchTransformer model to TorchScript.
(SwitchTransformer is a mixture-of-experts (MoE) model based on Google's T5.)
Tracing T5 works without errors, but tracing SwitchTransformer fails with the following error:
/root/HuggingFace/.HF/lib/python3.8/site-packages/transformers/modeling_utils.py:776: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if causal_mask.shape[1] < attention_mask.shape[1]:
Traceback (most recent call last):
  File "example.py", line 423, in <module>
    traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
  File "/root/HuggingFace/.HF/lib/python3.8/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/root/HuggingFace/.HF/lib/python3.8/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions
I suspect the cause is the dynamic behavior of SwitchTransformer (its MoE routing).
This is the code for T5.
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small', torchscript=True)
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
attention_mask = input_ids.ne(model.config.pad_token_id).long()
decoder_input_ids = tokenizer('<pad> <extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
torch.jit.save(traced_model, "traced_t5.pt")
And this is the code for SwitchTransformer.
from transformers import (
    AutoTokenizer,
    SwitchTransformersConfig,
    SwitchTransformersForConditionalGeneration,
)
import torch
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8", resume_download=True)
# Model, loaded in bfloat16 with torchscript=True as in the T5 example
model = SwitchTransformersForConditionalGeneration.from_pretrained(
    "google/switch-base-8",
    resume_download=True,
    torch_dtype=torch.bfloat16,
    torchscript=True,
)
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
output_text = "<pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
decoder_input_ids = tokenizer(output_text, return_tensors="pt", padding=True).input_ids
attention_mask = input_ids.ne(model.config.pad_token_id).long()
# model.eval()
traced_model = torch.jit.trace(model, (input_ids, attention_mask, decoder_input_ids))
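The error message says that something other than a tensor is among the traced outputs, so one way to narrow this down might be to inspect what the forward pass returns and then trace a thin wrapper that only returns tensors. Below is a minimal sketch of that idea on a toy module. ToyModel, find_non_tensors and FirstOutputOnly are hypothetical names used for illustration and are not part of transformers. I have not confirmed that SwitchTransformer actually returns a None or other non-tensor entry; that is an assumption this sketch would help check.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a model whose forward returns a non-tensor entry."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    def forward(self, x):
        # Assumption: the second entry mimics an optional output that is None.
        return self.proj(x), None


def find_non_tensors(obj, path="out"):
    """Return the paths of all leaves in a nested output that are not tensors."""
    if isinstance(obj, torch.Tensor):
        return []
    if isinstance(obj, (list, tuple)):
        found = []
        for i, item in enumerate(obj):
            found += find_non_tensors(item, f"{path}[{i}]")
        return found
    if isinstance(obj, dict):
        found = []
        for key, item in obj.items():
            found += find_non_tensors(item, f"{path}[{key!r}]")
        return found
    return [f"{path}: {type(obj).__name__}"]


class FirstOutputOnly(nn.Module):
    """Wrapper that exposes only the first output (assumed here to be a tensor)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args):
        return self.model(*args)[0]


model = ToyModel().eval()
x = torch.randn(3, 4)

# Step 1: run the model eagerly and list every output leaf that is not a tensor.
with torch.no_grad():
    bad_leaves = find_non_tensors(model(x))
print(bad_leaves)

# Step 2: trace the wrapper, whose only output is a tensor.
traced = torch.jit.trace(FirstOutputOnly(model), (x,))
print(traced(x).shape)
```

For the real model, the same check would be find_non_tensors(model(input_ids, attention_mask, decoder_input_ids)), and the wrapper would be traced with the same three inputs. Whether index 0 is the right output to keep depends on what the model returns, so the list printed in step 1 should be checked first.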