How to jit.trace GPT-Neo-125M

Right now I'm doing this:

inputs = torch.tensor([tokenizer.encode("The Manhattan bridge")])

traced_script_module = torch.jit.trace(model, inputs)

And I got this error:
Tracer cannot infer type of CausalLMOutputWithPast(loss=None, logits=tensor([[[ -7.3835, -6.2460, -8.1929, ...
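The tracer fails here because the model's forward returns a `CausalLMOutputWithPast` object rather than a plain tensor or tuple. One workaround that has worked for similar causal LMs is loading the model with `torchscript=True` (which makes forward return tuples instead of `ModelOutput` objects) and passing `strict=False` to `torch.jit.trace` so it accepts the nested `past_key_values` tuples. A minimal sketch, assuming the `EleutherAI/gpt-neo-125m` checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# torchscript=True makes the model return plain tuples, which the
# tracer can infer types for (unlike CausalLMOutputWithPast).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m", torchscript=True
)
model.eval()

input_ids = tokenizer("The Manhattan bridge", return_tensors="pt").input_ids

# strict=False relaxes the tracer's checks on container-type outputs
# such as the past_key_values tuple of tuples.
traced = torch.jit.trace(model, (input_ids,), strict=False)

with torch.no_grad():
    logits = traced(input_ids)[0]  # first element of the tuple is logits
print(logits.shape)  # (batch, seq_len, vocab_size)
```

Note that a traced module is specialized to the example input's shapes and control flow, so generation loops with growing sequence lengths may still need care.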

cc @valhalla

Any updates here? Can causal LM models be traced with torch.jit?
