Code snippet that is causing the error:
def compile_torch_to_mhlo(model,data):
print('Compile torch program to mhlo test\n------\n')
import torch_mlir
module = torch_mlir.compile(
model,
data,
output_type=torch_mlir.OutputType.STABLEHLO,
use_tracing = False
)
print(f"StableHLO={module}\n------\n")
if __name__ == '__main__':
...
data = torch.ones(args.batch,3,224,224)
config = AutoConfig.from_pretrained(args.model_type,num_labels=num_classes)
model = CustomViTForImageClassification(config)
model.load_state_dict(checkpoint['state_dict'], strict = False)
compile_torch_to_mhlo(model, data)
Error trace below:
/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/annotations.py:386: UserWarning: TorchScript will treat type annotations of Tensor dtype-specific subtypes as if they are normal Tensors. dtype constraints are not enforced in compilation either.
warnings.warn(
Traceback (most recent call last):
File "/home/nhd7682/SNL_VIT/mpc_inference.py", line 164, in <module>
compile_torch_to_mhlo(model, data)
File "/home/nhd7682/SNL_VIT/mpc_inference.py", line 132, in compile_torch_to_mhlo
module = torch_mlir.compile(
^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch_mlir/__init__.py", line 419, in compile
scripted = torch.jit.script(model)
^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
return torch.jit._recursive.create_script_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
create_methods_and_properties_from_stubs(
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(
RuntimeError:
'NoneType' object has no attribute or method 'expand'.:
File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/transformers/models/vit/modeling_vit.py", line 126
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
I am not sure if this is a bug or I am doing something incorrectly. I am using Torch-MLIR to lower the model to StableHLO which in turn using torch.jit.script method as evident from the error trace.