### 🐛 Describe the bug
I'd like to compile a function using `fullgraph=True`, where the function makes forward passes through a module that checks whether we're in inference mode using either `torch.is_inference_mode_enabled()` or `Tensor.is_inference()`, but I'm getting the error shown in the logs below. The check is there to disambiguate forward calls that both run under `torch.no_grad()` but where one forward pass should behave differently; in my case it's for using KV-caching in a compiled generation function.
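The pattern can be sketched roughly as follows. This is a minimal illustration, not the actual torchtune code: `ToyAttention`, its `nn.Identity()` placeholder for `kv_cache`, and the `try_compile` helper are hypothetical names, and `backend="eager"` is an assumption chosen so that only Dynamo tracing is exercised. Whether the compiled call raises `Unsupported` is expected to depend on the installed PyTorch version, so the sketch records the exception instead of assuming one.

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for the attention module in the traceback."""

    def __init__(self):
        super().__init__()
        # Placeholder for a real KV cache; only its "not None"-ness matters here.
        self.kv_cache = nn.Identity()

    def forward(self, x):
        # Same shape of condition as attention.py:268 in the logs.
        if self.kv_cache is not None and x.is_inference():
            return x + 1  # stands in for the KV-cache code path
        return x  # stands in for the regular code path


def try_compile(module, x):
    """Run `module` under torch.compile(fullgraph=True); return (output, error)."""
    try:
        compiled = torch.compile(module, fullgraph=True, backend="eager")
        return compiled(x), None
    except Exception as exc:  # e.g. torch._dynamo.exc.Unsupported on affected builds
        return None, exc


model = ToyAttention()
with torch.inference_mode():
    x = torch.zeros(2)  # created under inference mode, so x.is_inference() is True
    eager_out = model(x)
    compiled_out, error = try_compile(model, x)

print("eager output:", eager_out)
print("compiled output:", compiled_out)
print("compile error:", type(error).__name__ if error is not None else None)
```

On a build affected by this issue, I would expect `error` to be the `Unsupported` exception shown in the logs below, while the eager call takes the cache branch without problems.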
### Error logs
```bash
from user code:
File "/home/salman/torchtune/torchtune/generation/_generation.py", line 84, in generate_next_token
logits = model(x, input_pos=input_pos, mask=mask, cache_pos=cache_pos)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/home/salman/torchtune/torchtune/modules/transformer.py", line 466, in forward
h = layer(
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/home/salman/torchtune/torchtune/modules/transformer.py", line 107, in forward
attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos, cache_pos=cache_pos)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/home/salman/torchtune/torchtune/modules/attention.py", line 268, in forward
if self.kv_cache is not None and x.is_inference():
```
and
```bash
V0908 16:22:27.745000 364546 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_source] TRACE starts_line /home/salman/torchtune/torchtune/modules/attention.py:268 in forward (MultiHeadAttention.forward) (inline depth: 6)
V0908 16:22:27.745000 364546 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_source] if self.kv_cache is not None and x.is_inference():
V0908 16:22:27.745000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE LOAD_FAST self []
V0908 16:22:27.745000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE LOAD_ATTR kv_cache [NNModuleVariable()]
V0908 16:22:27.745000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [NNModuleVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE IS_OP 0 [NNModuleVariable(), ConstantVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE POP_JUMP_FORWARD_IF_TRUE 1302 [ConstantVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE LOAD_METHOD is_inference [TensorVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE PRECALL 0 [NullVariable(), GetAttrVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:815] [0/0] [__trace_bytecode] TRACE CALL 0 [NullVariable(), GetAttrVariable()]
V0908 16:22:27.746000 364546 torch/_dynamo/output_graph.py:1908] [0/0] [__trace_call] TRACE FX call is_inference from /home/salman/torchtune/torchtune/modules/attention.py:268 in forward (MultiHeadAttention.forward) (inline depth: 6)
V0908 16:22:27.746000 364546 torch/_dynamo/output_graph.py:1908] [0/0] [__trace_call] if self.kv_cache is not None and x.is_inference():
V0908 16:22:27.746000 364546 torch/_dynamo/output_graph.py:1908] [0/0] [__trace_call] ~~~~~~~~~~~~~~^^
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object forward at 0x5626a15d7d10, file "/home/salman/torchtune/torchtune/modules/attention.py", line 161>
V0908 16:22:27.746000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object _call_impl at 0x56269c10a010, file "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1720>
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object forward at 0x7f40483e57c0, file "/home/salman/torchtune/torchtune/modules/transformer.py", line 67>
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object _call_impl at 0x56269c10a010, file "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1720>
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object forward at 0x5626a1579290, file "/home/salman/torchtune/torchtune/modules/transformer.py", line 374>
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:2954] [0/0] FAILED INLINING <code object _call_impl at 0x56269c10a010, file "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1720>
V0908 16:22:27.747000 364546 torch/_dynamo/symbolic_convert.py:831] [0/0] empty checkpoint
Traceback (most recent call last):
File "/home/salman/torchtune/target/generate_test.py", line 239, in <module>
generated_tokens, t = generate_with_compile()
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/torchtune/target/generate_test.py", line 127, in generate_with_compile
generation.generate(
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/torchtune/torchtune/generation/_generation.py", line 288, in generate
tokens, logits = custom_generate_next_token(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 448, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1171, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in __call__
return _compile(
^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 851, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 669, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 195, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 611, in transform
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2609, in run
super().run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in CALL
self._call(inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2144, in _call
self.call_function(fn, args, kwargs)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 132, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 433, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1563, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in CALL
self._call(inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2144, in _call
self.call_function(fn, args, kwargs)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 433, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1563, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in CALL
self._call(inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2144, in _call
self.call_function(fn, args, kwargs)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 433, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1563, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2824, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2940, in inline_call_
tracer.run()
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
while self.step():
^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in CALL
self._call(inst)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2144, in _call
self.call_function(fn, args, kwargs)
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 745, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 507, in call_method
return wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1863, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2152, in wrap_fx_proxy_cls
unimplemented(
File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 229, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_method is_inference
```
### Minified repro
_No response_
### Versions
```bash
Collecting environment information...
PyTorch version: 2.5.0.dev20240725+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.11.9 (main, Jul 25 2024, 15:18:30) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 SUPER
Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 12
On-line CPU(s) list: 0-11
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 113
Model name: AMD Ryzen 5 3600 6-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2200.000
CPU max MHz: 3600.0000
CPU min MHz: 2200.0000
BogoMIPS: 7187.06
Virtualisation: AMD-V
L1d cache: 192 KiB
L1i cache: 192 KiB
L2 cache: 3 MiB
L3 cache: 32 MiB
NUMA node0 CPU(s): 0-11
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240725+cu121
[pip3] torchao==0.4.0
[pip3] torchao-nightly==2024.8.16+cu121
[pip3] torchtune==0.0.0
[pip3] torchvision==0.20.0.dev20240725+cu121
```
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec