device_map not working for ORTModelForSeq2SeqLM - Potential bug?

I have a non-KV-cache T5 ONNX model that I’m trying to load. There are two issues:

with CUDAExecutionProvider

model = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="CUDAExecutionProvider", task="text2text-generation", use_cache=False, use_io_binding=False, device_map="cuda:3")

I load the with_past (KV cache) variant of the model the same way:

model_kv = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="CUDAExecutionProvider", task="text2text-generation-with-past", use_cache=True, use_io_binding=True, device_map="cuda:3")

Both load fine, except that both ignore device_map and end up on cuda:0. That means I can’t load two models side by side for testing or serving, since my model is relatively large and only one fits on a GPU.
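One thing that might be worth trying, as a sketch rather than a confirmed fix: ONNX Runtime's CUDA execution provider takes a device_id provider option, and from_pretrained accepts a provider_options dict. The helper name provider_options_for is hypothetical, and whether this actually places the model on cuda:3 in this setup is an assumption, not something verified here.

```python
import re


def provider_options_for(device: str) -> dict:
    """Turn a torch-style device string such as "cuda:3" into provider options.

    Hypothetical helper: it only builds the dict; it does not load anything.
    """
    match = re.fullmatch(r"cuda(?::(\d+))?", device.strip().lower())
    if match is None:
        raise ValueError(f"expected 'cuda' or 'cuda:<index>', got {device!r}")
    # A bare "cuda" is treated as device 0.
    return {"device_id": int(match.group(1) or 0)}


# Intended use (not run here; needs optimum, onnxruntime-gpu and a GPU):
# model = ORTModelForSeq2SeqLM.from_pretrained(
#     model_name,
#     provider="CUDAExecutionProvider",
#     provider_options=provider_options_for("cuda:3"),
#     use_cache=False,
#     use_io_binding=False,
# )
```

If this works, model.device should report index 3 straight after loading, with no .to() call needed.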

with TensorrtExecutionProvider

Then I tried TensorrtExecutionProvider (on the non-KV-cache model). The model reports that it is loaded on the wrong CUDA device, but it is actually on the CPU. Because the model is really on the CPU, it is immensely slow to load.

model = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="TensorrtExecutionProvider", task="text2text-generation", use_cache=False, use_io_binding=False, device_map="cuda:3")

model.device shows device(type='cuda', index=0). However, nvitop shows hardly any memory in use on cuda:0 (3746MB used, up from my baseline of 872MB; it could be that the model graph is indeed loaded onto cuda:0 but the data isn’t). Main CPU memory usage, on the other hand, has increased by 60GB, the size of my model. So the model that I wanted on cuda:3 thinks it’s on cuda:0 but is actually on the CPU.
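To compare per-GPU memory before and after loading without eyeballing nvitop, a small sketch along these lines could help. It assumes nvidia-smi is on PATH; the function names are hypothetical, and the values are in MiB as reported by nvidia-smi.

```python
import subprocess

# Ask nvidia-smi for "index, memory.used" per GPU as bare CSV.
QUERY = ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"]


def parse_memory_used(csv_text: str) -> dict:
    """Map GPU index -> memory used (MiB) from nvidia-smi CSV output."""
    used = {}
    for line in csv_text.strip().splitlines():
        index, mib = (field.strip() for field in line.split(","))
        used[int(index)] = int(mib)
    return used


def memory_growth(before: dict, after: dict) -> dict:
    """Per-GPU increase in used memory between two snapshots."""
    return {gpu: after[gpu] - before.get(gpu, 0) for gpu in after}


def snapshot() -> dict:
    """Take a snapshot of used memory on every visible GPU (needs nvidia-smi)."""
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
    return parse_memory_used(out)


# Intended use:
# before = snapshot()
# model = ...  # load the model here
# print(memory_growth(before, snapshot()))
```

With the cuda:0 numbers above (872 before, 3746 after), memory_growth gives an increase of 2874 on GPU 0, far short of the 60GB model size.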

I have this in my OS environment:

export ONNX_MODE=cuda
export PATH=/usr/local/cuda-11.8/bin:$PATH
export CUDA_PATH=/usr/local/cuda
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:/home/zm/workspace/packages/TensorRT-$LD_LIBRARY_PATH

How can I fix this? What I want in both cases is to be able to put the model on whichever device (cuda:0 to cuda:7) I choose. I can get it there with .to("cuda:3"), but that essentially triggers a full reload, which is slow.