I have a non-KV-cache T5 ONNX model that I’m trying to load. There are two issues.
With CUDAExecutionProvider
model = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="CUDAExecutionProvider", task="text2text-generation", use_cache=False, use_io_binding=False, device_map="cuda:3")
The same thing happens with a model that has with_past:
model_kv = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="CUDAExecutionProvider", task="text2text-generation-with-past", use_cache=True, use_io_binding=True, device_map="cuda:3")
Both models load fine, except that both ignore device_map and load onto cuda:0, which means I can’t load multiple models for testing or serving, since my model is relatively large and only one fits on a GPU.
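A possible workaround, though this is an assumption on my part: I believe Optimum forwards a provider_options dict to the underlying onnxruntime InferenceSession, and that CUDAExecutionProvider accepts a device_id option selecting the physical GPU. If so, the target device could be passed explicitly instead of device_map:

```python
# Sketch, not a verified fix: assumes ORTModelForSeq2SeqLM.from_pretrained
# accepts a `provider_options` dict and forwards it to onnxruntime, where
# CUDAExecutionProvider's `device_id` option selects the physical GPU.
def cuda_provider_kwargs(device_index: int) -> dict:
    """Build kwargs intended to pin an ORT session to one GPU."""
    return {
        "provider": "CUDAExecutionProvider",
        "provider_options": {"device_id": device_index},
    }

kwargs = cuda_provider_kwargs(3)
# model = ORTModelForSeq2SeqLM.from_pretrained(
#     model_name, use_cache=False, use_io_binding=False, **kwargs)
```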
With TensorrtExecutionProvider
Then I tried TensorrtExecutionProvider (on the non-KV-cache model). The model thinks it’s loaded onto the wrong CUDA device, but it’s actually on the CPU. Because the model is really on the CPU, it’s immensely slow to load.
model = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="TensorrtExecutionProvider", task="text2text-generation", use_cache=False, use_io_binding=False, device_map="cuda:3")
model.device shows device(type='cuda', index=0). However, nvitop shows hardly any memory use on cuda:0 (3746MB used, up from my base of 872MB - it could be that the model graph is indeed loaded onto cuda:0 but the weights aren’t). If I look at main CPU memory, utilization has increased by 60GB, the size of my model. So the model that I wanted to load onto cuda:3 thinks it’s on cuda:0 but is actually on the CPU.
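To confirm this kind of silent fallback, onnxruntime’s InferenceSession.get_providers() returns the providers the session actually selected (falling back to CPUExecutionProvider is silent). Exactly where Optimum exposes the sessions (e.g. model.encoder.session) is an assumption on my part; the helper itself is plain Python:

```python
# Diagnostic sketch: checks whether a GPU-backed execution provider is really
# active, rather than a silent fallback to CPUExecutionProvider.
def actually_on_gpu(active_providers: list) -> bool:
    """True if any GPU-backed execution provider is in the active list."""
    gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider"}
    return any(p in gpu_providers for p in active_providers)

# e.g. (assumed attribute path):
# actually_on_gpu(model.encoder.session.get_providers())
```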
I have this in my OS environment:
export ONNX_MODE=cuda
export PATH=/usr/local/cuda-11.8/bin:$PATH
export CUDA_PATH=/usr/local/cuda
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:/home/zm/workspace/packages/TensorRT-8.6.1.6/lib:$LD_LIBRARY_PATH
How can I fix this? What I want in both cases is to be able to put the model on whatever device (cuda:0 through cuda:7) I want. I can get it there with a .to("cuda:3"), but that essentially forces a full reload, which is slow.
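Another route I’m considering, relying on standard CUDA behavior rather than anything Optimum-specific: pin each serving process to one physical GPU via CUDA_VISIBLE_DEVICES before the process starts, so that cuda:0 inside that process maps to the physical GPU chosen for it (the serve_model.py worker script here is hypothetical):

```python
import os

# Workaround sketch using standard CUDA behavior: setting CUDA_VISIBLE_DEVICES
# before a process starts means that process only sees the listed GPU, so
# cuda:0 inside it maps to the physical GPU chosen here.
def env_pinned_to_gpu(gpu_index: int) -> dict:
    """Copy of the current environment restricted to one physical GPU."""
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env

# One worker process per model, each pinned to its own GPU, e.g.:
# subprocess.Popen(["python", "serve_model.py"], env=env_pinned_to_gpu(3))
```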