I am running a Mistral model on a remote server over SSH. When I try to generate output conditioned on input embeddings, the connection closes unexpectedly. Interestingly, this only happens when I use Accelerate to load the model onto multiple GPUs.
I load the Mistral model as follows:
model = AutoModelForCausalLM.from_pretrained(
"AIRI-Institute/OmniFusion",
subfolder="OmniMistral-v1_1/tuned-model",
torch_dtype=torch.bfloat16,
device_map="auto"
)
Then I run the generation:
model.generate(inputs_embeds=embeddings, max_new_tokens=50)
At this point, the notebook freezes (I set the logging level to debug, but the call did not output anything), and I get disconnected from the remote server. I also checked that the model is indeed loaded onto multiple GPUs and not onto the CPU, and there were no errors at that stage.
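For reference, the placement check mentioned above can be sketched roughly as follows. This assumes the model exposes the `hf_device_map` attribute that is set when loading with a `device_map`; the helper name `summarize_device_map` and the example mapping are hypothetical and only illustrate the idea:

```python
from collections import Counter


def summarize_device_map(device_map):
    """Count modules per device and list any that were offloaded to CPU or disk."""
    # Device values may be GPU indices (ints) or strings such as "cpu" or "disk".
    counts = Counter(str(device) for device in device_map.values())
    offloaded = sorted(
        name for name, device in device_map.items() if str(device) in ("cpu", "disk")
    )
    return dict(counts), offloaded


# Hypothetical placement, for illustration only (not the real OmniFusion map).
example_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 1,
    "lm_head": 1,
}
counts, offloaded = summarize_device_map(example_map)
print(counts)     # modules per device
print(offloaded)  # modules that ended up on CPU or disk, if any
```

With the real model, the call would presumably be `summarize_device_map(model.hf_device_map)`; an empty second result would match the observation that nothing ended up on the CPU.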
However, if I load the model onto a single GPU:
model = AutoModelForCausalLM.from_pretrained(
"AIRI-Institute/OmniFusion",
subfolder="OmniMistral-v1_1/tuned-model",
torch_dtype=torch.bfloat16,
device_map="cuda:0"
)
The problem disappears: model.generate(inputs_embeds=embeddings, max_new_tokens=50) works as expected and finishes in less than 2 seconds.
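Since the single-GPU call returns within about 2 seconds, one way to see where the multi-GPU call stalls might be a watchdog built on Python's standard `faulthandler` module. The sketch below is only a diagnostic aid under stated assumptions: the helper name `call_with_hang_dump`, the timeout, and the log file name are hypothetical, and the dump goes to a file so that it survives the SSH disconnect (it shows Python-level stack frames only, so a stall inside a CUDA kernel or driver would appear as the last Python call before it):

```python
import faulthandler


def call_with_hang_dump(fn, *args, timeout_s=60.0, log_path="generate_hang.log", **kwargs):
    """Call fn; if it has not returned after timeout_s seconds, dump all thread tracebacks to log_path."""
    with open(log_path, "w") as log_file:
        # Arm the watchdog; the file must stay open until the watchdog is cancelled.
        faulthandler.dump_traceback_later(timeout_s, file=log_file)
        try:
            return fn(*args, **kwargs)
        finally:
            # Disarm the watchdog once the call returns or raises.
            faulthandler.cancel_dump_traceback_later()
```

In the notebook this would be used as `call_with_hang_dump(model.generate, inputs_embeds=embeddings, max_new_tokens=50)`; if the call hangs for longer than the timeout, the log file should show which Python frame it was in at that moment.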
What could cause such behaviour?
I am not sure whether this should be opened as an issue, so I am starting by asking it as a discussion topic. Thanks in advance for your answers!