Python 3.9
Model: Mistral-7B-Instruct-v0.2
torch 2.1.1
transformers 4.38.2
NVIDIA driver 545.29.06, CUDA 12.3
2x NVIDIA RTX A4000 (16 GB each)
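For reference, a quick way to verify that torch sees both cards (a sanity-check sketch, not part of the pipeline below):

import torch

print(torch.__version__)                     # 2.1.1
print(torch.cuda.device_count())             # expect 2
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))  # both should report an RTX A4000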
import torch
from torch import stack
from torch.nn import Module
from torch.nn.functional import pad
from transformers import MistralForCausalLM


class Model(Module):
    def __init__(self,
                 model: MistralForCausalLM,
                 max_new_tokens: int = 512):
        super(Model, self).__init__()
        self.model = model
        self.max_new_tokens = max_new_tokens

    def forward(self, inputs):
        with torch.no_grad():
            # Gets past this line on cuda:0, but fails on cuda:1.
            outputs = self.model.generate(**inputs,
                                          max_new_tokens=self.max_new_tokens)
            # Strip the prompt (input) from each prediction, then pad each one
            # with the eos token id so every prediction has length max_new_tokens.
            predictions = []
            for idx, prediction in enumerate(outputs):
                prediction = prediction[inputs["input_ids"][idx].shape[0]:]
                prediction = pad(prediction,
                                 pad=(0, self.max_new_tokens - prediction.shape[0]),
                                 mode="constant",
                                 value=self.model.generation_config.eos_token_id)
                predictions.append(prediction)
            # Stack the equal-length predictions into a single batch tensor.
            predictions = stack(predictions, dim=0)
            return predictions
from torch import bfloat16
from transformers import BitsAndBytesConfig

# Must quantize the model to fit it on a 16 GB GPU.
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=bfloat16)
self.model = MistralForCausalLM.from_pretrained(model_dir,
                                                device_map="cuda",
                                                quantization_config=bnb_config)
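(The tokenization side isn't shown here; the dataset yields prompts that get tokenized into the inputs dicts the model consumes. The tokenizer settings and the collate helper below are an illustrative sketch, not the actual code.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.padding_side = "left"             # pad on the left so generation continues right after the prompt
tokenizer.pad_token = tokenizer.eos_token   # Mistral's tokenizer has no pad token by default

def collate(prompts):
    # Returns a dict with "input_ids" and "attention_mask" tensors,
    # which is what Model.forward unpacks into generate().
    return tokenizer(prompts, return_tensors="pt", padding=True)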
from torch.nn import DataParallel
from torch.utils.data import DataLoader

# Create the DataLoader that will feed the model prediction loop.
dataloader = DataLoader(dataset,
                        batch_size=batch_size)
# Create a module wrapper for the model that will be used for the prediction loop.
model_module = Model(self.model,
                     max_new_tokens=max_new_tokens)
# Wrap it so it will run on both GPUs.
model_module = DataParallel(model_module)
# Prediction loop.
preds = []
for inputs in dataloader:
    preds.extend(model_module(inputs))
As expected, torch.nn.DataParallel loads the model on both GPUs and splits the batch, sending half to each GPU.
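For each batch, the wrapper does roughly the following under the hood (simplified sketch using the torch.nn.parallel primitives, with the device ids hard-coded to [0, 1]):

from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [0, 1]
replicas = replicate(model_module.module, device_ids)  # copy the wrapped module onto each GPU
chunks = scatter(inputs, device_ids)                   # split every tensor in the inputs dict along dim 0
outputs = parallel_apply(replicas, chunks)             # run each replica's forward; this is the step that fails
preds_batch = gather(outputs, device_ids[0])           # gather the results back onto cuda:0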
The failure occurs at line 85 in torch/nn/parallel/parallel_apply.py:
output = module(*input, **kwargs)
I spelunked this in the debugger and confirmed that the model and the inputs are BOTH on cuda:1, so the issue is an utter mystery to me.
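(The check was along these lines, from inside the replica's forward; illustrative, not the exact debugger session:)

print(next(self.model.parameters()).device)  # -> cuda:1 on the second replica
print(inputs["input_ids"].device)            # -> cuda:1 as well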
There is NO issue when it runs on cuda:0.
I have successfully used DataParallel on numerous other HF models. This is a real head scratcher.
BTW, if I remove the DataParallel wrapper, it of course only runs on cuda:0 and there are no issues.
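In other words, this variant runs cleanly (sketch; the explicit .to() is only needed if the DataLoader yields CPU tensors):

# Single-GPU version: no DataParallel wrapper, everything stays on cuda:0.
model_module = Model(self.model, max_new_tokens=max_new_tokens)

preds = []
for inputs in dataloader:
    inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
    preds.extend(model_module(inputs))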
Any/all insight is greatly appreciated.