Expected Tensors to be on the same device using Trainer

I get a “RuntimeError: Expected all tensors to be on the same device , but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)” when running the following code.

import datasets
from transformers import AutoTokenizer
import torch
import transformers
from transformers import AutoModelForQuestionAnswering


MAX_LENGTH=200


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
                   
# from: https://huggingface.co/docs/transformers/tasks/question_answering
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",  # only truncate the context
        return_offsets_mapping=True,
        padding="max_length",
    )
    
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:  # TODO - can throw an error here if the max_length isn't long enough. I think, I can add a check to just end the sequence once max_length is reached
            idx += 1
            if idx >= len(sequence_ids):  # TODO - I added this, but does that make the next if statement unnecessary?
                break
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    return inputs


def load_squad_data():
    from datasets import load_dataset
    squad = load_dataset("squad", split="train[:5000]")
    squad = squad.train_test_split(test_size=0.2)
    tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)
    return tokenized_squad



##############################################################
##############################################################
##############################################################
##############################################################
##############################################################

# create the model and tokenizer
model_name = "facebook/opt-6.7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name, device_map='auto')

# load and process the data
data = load_squad_data()

# set up the trainer and train
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,
    num_train_epochs=25,
    learning_rate=2e-4,
    output_dir='training_output',
)

data_collator = transformers.DefaultDataCollator()

trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    args=training_args,
    data_collator=data_collator
)

trainer.train()

The full error is below:

python3 code/min_reproducable.py 
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.12s/it]
Some weights of OPTForQuestionAnswering were not initialized from the model checkpoint at facebook/opt-6.7b and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 2348.15 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2429.45 examples/s]
2024-05-02 16:49:00.310155: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-02 16:49:01.124980: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
  0%|                                                                                                                                                                      | 0/25000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/shenry/llm_for_bioasq/code/min_reproducable.py", line 110, in <module>
    trainer.train()
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 1436, in forward
    start_loss = loss_fct(start_logits, start_positions)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/shenry/llm_for_bioasq/virtual_env_ne/lib64/python3.9/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
  0%|          | 0/25000 [00:02<?, ?it/s]

I have some guesses as to what is going on, but I am stuck. Any help would be appreciated. Thank you!