How to modify a model class loaded with AutoModelForCausalLM.from_config

I’m working with DeepSpeed and Transformers for distributed inference. The model is LLaMA-7B in FP16, loaded with AutoModelForCausalLM.from_config.

I want to profile the decoding stage, so I need to wrap torch.profiler around decoding and omit prefilling. One way I think could achieve this is to modify functions in transformers.generation.utils such as sample() or greedy_search(), but I don’t want to change the source code of transformers.
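A minimal sketch of what I have in mind: monkey-patch the method at runtime instead of editing the installed package. Stand-in classes are used below so the pattern runs anywhere without a GPU; in practice the target would be transformers.generation.utils.GenerationMixin.greedy_search (assuming a transformers version that still exposes it) and the context manager would be torch.profiler.profile — the names fake_profiler and the stand-in GenerationMixin here are illustrative only.

```python
# Pattern: wrap a library method with a profiler context without editing
# its source, by re-assigning the attribute on the class (monkey-patch).
from contextlib import contextmanager

calls = []

@contextmanager
def fake_profiler():  # stand-in for torch.profiler.profile(...)
    calls.append("enter")
    yield
    calls.append("exit")

class GenerationMixin:  # stand-in for transformers' GenerationMixin
    def greedy_search(self, input_ids):
        calls.append("greedy_search")
        return input_ids + ["<token>"]

_original = GenerationMixin.greedy_search

def profiled_greedy_search(self, input_ids):
    # Everything greedy_search does runs inside the profiler context.
    with fake_profiler():
        return _original(self, input_ids)

GenerationMixin.greedy_search = profiled_greedy_search  # the monkey-patch

out = GenerationMixin().greedy_search(["<prompt>"])
print(calls)  # ['enter', 'greedy_search', 'exit']
```

One caveat: in real transformers the first forward pass over the prompt (prefill) also happens inside greedy_search's loop, so wrapping the whole method profiles prefill too.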

Is there a way to add torch.profiler inside greedy_search()? Or is there a better method to profile the decoding stage?