torch.nn.DataParallel Mistral-7B-Instruct RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Python 3.9
Model: Mistral-7B-Instruct-v0.2
torch 2.1.1
transformers 4.38.2
NVIDIA-SMI 545.29.06
Driver Version: 545.29.06
CUDA Version: 12.3
2x NVIDIA RTX A4000

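import torch
from torch import bfloat16, stack
from torch.nn import DataParallel, Module
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import BitsAndBytesConfig, MistralForCausalLM


class Model(Module):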
    def __init__(self,
                 model: MistralForCausalLM,
                 max_new_tokens: int = 512):
        super(Model, self).__init__()
        self.model = model
        self.max_new_tokens = max_new_tokens

    def forward(self,
                inputs):
        with torch.no_grad():
            # Gets past this line on cuda:0, but fails on cuda:1.
            outputs = self.model.generate(**inputs,
                                          max_new_tokens=self.max_new_tokens)

        # Strip the prompt (input) from each prediction, then right-pad it with the EOS
        # token id so that every prediction has length max_new_tokens.
        predictions = []
        for idx, prediction in enumerate(outputs):
            prediction = prediction[inputs["input_ids"][idx].shape[0]:]
            prediction = pad(prediction,
                             pad=(0, self.max_new_tokens - prediction.shape[0]),
                             mode="constant",
                             value=self.model.generation_config.eos_token_id)
            predictions.append(prediction)

        # Stack the equal-length predictions into a single (batch, max_new_tokens) tensor.
        predictions = stack(predictions,
                            dim=0)
        return predictions

                # Must quantize the model to fit it on a 16 GB GPU.
                bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                                bnb_4bit_use_double_quant=True,
                                                bnb_4bit_quant_type="nf4",
                                                bnb_4bit_compute_dtype=bfloat16)
                self.model = MistralForCausalLM.from_pretrained(model_dir,
                                                                device_map="cuda",
                                                                quantization_config=bnb_config)

        # Create the DataLoader that will feed the model prediction loop.
        dataloader = DataLoader(dataset,
                                batch_size=batch_size)

        # Create a module wrapper for the model that will be used for the prediction loop.
        model_module = Model(self.model,
                             max_new_tokens=max_new_tokens)

        # Will run on both GPUs
        model_module = DataParallel(model_module)

        # Prediction loop.
        preds = []
        for inputs in dataloader:
            preds.extend(model_module(inputs))

As expected, torch.nn.DataParallel loads the model on both GPUs and splits the batch, sending half to each GPU.

The failure occurs at line 85 in torch/nn/parallel/parallel_apply.py:

                output = module(*input, **kwargs)

I spelunked this in the debugger and confirmed that the model and the inputs are BOTH on cuda:1, so the issue is an utter mystery to me.
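
For anyone who wants to repeat that device check without stepping through a debugger, here is a minimal sketch of how it could be done. The helper names (devices_by_tensor, input_devices, is_single_device) are hypothetical and not part of torch or transformers. It only relies on named_parameters() and named_buffers(), which every torch.nn.Module provides, and on the .device attribute of each tensor.

```python
from collections import defaultdict


def devices_by_tensor(module):
    """Group a module's parameter and buffer names by the device they live on.

    Works on any object exposing named_parameters() and named_buffers(),
    which every torch.nn.Module does.
    """
    found = defaultdict(list)
    for name, param in module.named_parameters():
        found[str(param.device)].append(name)
    for name, buf in module.named_buffers():
        found[str(buf.device)].append(name)
    return dict(found)


def input_devices(inputs):
    """Map each key of a tokenizer-style dict to the device of its tensor."""
    return {key: str(value.device) for key, value in inputs.items()}


def is_single_device(module, inputs):
    """Return True only if every parameter, buffer, and input shares one device."""
    devices = set(devices_by_tensor(module)) | set(input_devices(inputs).values())
    return len(devices) == 1
```

Calling is_single_device(self.model, inputs) at the top of forward would report per replica whether everything visible to torch sits on one device. Note the limit of this check: it only sees registered parameters and buffers, so any tensor that a library keeps as a plain Python attribute would not show up here, and a clean result does not rule out a stray tensor elsewhere.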

There is NO issue when it runs on cuda:0.

I have successfully used DataParallel on numerous other HF models. This is a real head scratcher.

BTW, if I remove the DataParallel wrapper, it of course runs only on cuda:0, and there are no issues.

Any/all insight is greatly appreciated.

Crickets . . .