I’m working with DeepSpeed and transformers for distributed inference. The model is LLaMA-7B at FP16, and I’m loading it with `AutoModelForCausalLM.from_config`.
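
For context, this is roughly how I build the model (the model id and tensor-parallel degree below are placeholders, and I've omitted the actual checkpoint loading):

```python
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder model id; actual weight loading (DeepSpeed checkpoint
# json, etc.) is omitted here.
config = AutoConfig.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_config(config).half()

# Wrap with the DeepSpeed inference engine; mp_size is the
# tensor-parallel degree (placeholder value).
engine = deepspeed.init_inference(model, mp_size=2, dtype=torch.half)
```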
I want to profile the decoding stage, so I need to add `torch.profiler` around decoding and omit prefilling. What I suppose could achieve this is modifying functions in `transformers.generation.utils`, such as `sample()` or `greedy_search()`, but I don't want to change the transformers source code. The monkey-patch sketch below is the kind of thing I have in mind.
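
A rough sketch, assuming a transformers version where `GenerationMixin.greedy_search` is still the method that `generate()` dispatches to:

```python
import torch
from torch.profiler import profile, ProfilerActivity
from transformers.generation.utils import GenerationMixin

# Keep a handle to the original method so the patch can delegate to it.
_orig_greedy_search = GenerationMixin.greedy_search

def profiled_greedy_search(self, *args, **kwargs):
    # Wraps the whole decode loop. Note this still includes the first
    # forward pass (the prefill), which is exactly what I want to omit.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        out = _orig_greedy_search(self, *args, **kwargs)
    prof.export_chrome_trace("greedy_search_trace.json")
    return out

GenerationMixin.greedy_search = profiled_greedy_search
```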
Any idea how I could add `torch.profiler` in `greedy_search()` along those lines, without touching the source? Or is there a better method for profiling the decoding stage?
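
I'm not sure whether this is sound, but one alternative I've been wondering about: since `generate()` evaluates the stopping criteria once per generated token inside the decode loop, maybe a no-op `StoppingCriteria` could serve as a per-step hook to drive a scheduled profiler that skips the first step (which contains the prefill). A sketch of that idea, with the step counts and trace directory as placeholders:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity
from transformers import StoppingCriteria, StoppingCriteriaList

class ProfilerStepHook(StoppingCriteria):
    # Uses StoppingCriteria purely as a per-token hook: generate()
    # evaluates it once per decode iteration, so prof.step() advances
    # the profiler schedule by one generated token each time.
    def __init__(self, prof):
        self.prof = prof

    def __call__(self, input_ids, scores, **kwargs):
        self.prof.step()
        return False  # never actually stop generation

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # wait=1 skips profiler step 0, which contains the prefill forward
    # pass; the warmup/active counts here are placeholders.
    schedule=schedule(wait=1, warmup=1, active=5),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./decode_trace"),
)
with prof:
    # `model` and `input_ids` come from the setup above.
    outputs = model.generate(
        input_ids,
        max_new_tokens=32,
        stopping_criteria=StoppingCriteriaList([ProfilerStepHook(prof)]),
    )
```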