Unable to torch.jit.trace a quantized BigBird (0INTERNAL ASSERT FAILED runtime error), but the same code works for BERT and RoBERTa

Hello,
I am trying to torch.jit.trace a quantized version of Transformers' implementation of BigBird, but I'm encountering a runtime error that I'm not familiar with.
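For context, here is a minimal sketch of the setup (the checkpoint name and the dynamic-quantization settings are representative placeholders; the exact code is in the notebook linked under the reproduction steps below):

```python
import torch
from transformers import BigBirdForSequenceClassification, BigBirdTokenizer

# Checkpoint is illustrative; any BigBird checkpoint should hit the same code path
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdForSequenceClassification.from_pretrained(
    "google/bigbird-roberta-base", torchscript=True
)
model.eval()

# Post-training dynamic quantization of the linear layers
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

inputs = tokenizer("Some example text.", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

traced_model = torch.jit.trace(model, (input_ids, attention_mask))  # fails here
torch.jit.save(traced_model, "traced_bigbird.pt")
```

The trace call fails with: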

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-1dfdd2340788> in <module>
      4 )
      5 
----> 6 traced_model = torch.jit.trace(model, (input_ids, attention_mask))
      7 torch.jit.save(traced_model, "traced_bigbird.pt")

/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    742             strict,
    743             _force_outplace,
--> 744             _module_class,
    745         )
    746 

/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    957                 strict,
    958                 _force_outplace,
--> 959                 argument_names,
    960             )
    961             check_trace_method = module._c._get_method(method_name)

RuntimeError: 0INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/alias_analysis.cpp":532, please report a bug to PyTorch. We don't have an op for aten::constant_pad_nd but it isn't a special case.  Argument types: Tensor, int[], bool, 

I’m trying to locate constant_pad_nd in the BigBird code base to figure out how it relates to ATen, but I’m not all that familiar with ATen either. That said, I ran the same code for BERT and RoBERTa and did not hit this issue: the quantized versions of both models traced successfully.
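In case it helps with debugging: as far as I can tell, aten::constant_pad_nd is the ATen operator that torch.nn.functional.pad lowers to when mode="constant", and BigBird's block-sparse attention pads its inputs up to a multiple of the attention block size, a step BERT and RoBERTa don't have. A quick sketch (my own minimal example, not taken from the BigBird source) that shows the mapping in a traced graph:

```python
import torch
import torch.nn.functional as F

class PadExample(torch.nn.Module):
    def forward(self, x):
        # Constant-mode padding is recorded as aten::constant_pad_nd when traced
        return F.pad(x, (0, 3), mode="constant", value=0)

traced = torch.jit.trace(PadExample(), torch.randn(2, 5))
print(traced.graph)  # the printed graph contains an aten::constant_pad_nd node
```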

To reproduce this error:

  1. Git clone this repo
  2. Run example.ipynb

Is anyone familiar with this, or does anyone know enough about the problem to help debug it?

I've created a GitHub issue here.
Also tagging @patrickvonplaten @sgugger @lewtun for more visibility :slight_smile: