Multi-GPU finetuning of NLLB produces RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0

This code fails on 2 or more GPUs, regardless of which pre-downloaded NLLB checkpoint is in my modelPath (I checked both NLLB-200-1.3B and NLLB-200-distilled-1.3B).

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.utils.data
from transformers import DataCollatorForSeq2Seq
import evaluate
import numpy as np
from argparse import ArgumentParser

modelPath = "initialmodel"

tokenizer = AutoTokenizer.from_pretrained(modelPath)
model = AutoModelForSeq2SeqLM.from_pretrained(modelPath, device_map="auto")

parser = ArgumentParser()
parser.add_argument('--source-lang', type=str, default='eng_Latn')
parser.add_argument('--target-lang', type=str, default='rus_Cyrl')
parser.add_argument('--delimiter', type=str, default=';')
args = parser.parse_args()

dff = pd.read_csv('dataset/data.csv', sep=args.delimiter)

source = dff[args.source_lang].values.tolist()
target = dff[args.target_lang].values.tolist()

max_length = 512
X_train, X_val, y_train, y_val = train_test_split(source, target, test_size=0.2)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
y_train_tokenized = tokenizer(y_train, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
y_val_tokenized = tokenizer(y_val, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

class ForDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # Tokenization above used return_tensors="pt", so these are already tensors;
        # index them directly instead of re-wrapping with torch.tensor()
        input_ids = self.inputs["input_ids"][index].squeeze()
        target_ids = self.targets["input_ids"][index].squeeze()

        return {"input_ids": input_ids, "labels": target_ids}

train_dataset = ForDataset(X_train_tokenized, y_train_tokenized)
test_dataset = ForDataset(X_val_tokenized, y_val_tokenized)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="pt")

metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

training_args = Seq2SeqTrainingArguments(
    output_dir="mymodel",
    evaluation_strategy="epoch",
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3, 
    num_train_epochs=20, 
    predict_with_generate=True,
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.save_model('finalmodel')

Contents of the shell script used to run my code:
python3 finetune.py --source-lang eng_Latn --target-lang rus_Cyrl --delimiter ';'

My finetuning data (placed in the file dataset/data.csv):

eng_Latn;rus_Cyrl
Mafia is a game which models a conflict between two groups: an informed minority (the mafiosi or the werewolves) and an uninformed majority (the villagers).;Мафия - клубная командная психологическая пошаговая ролевая игра с детективным сюжетом, моделирующая борьбу информированных друг о друге членов организованного меньшинства с неорганизованным большинством.
This is a test, this is another test.;Это тест, это еще одно испытание.
Batman: Arkham City is a computer game developed by Rocksteady Studios and published by Warner Bros. Interactive Entertainment.;Batman: Arkham City — компьютерная игра, разработанная британской студией Rocksteady Studios и изданная компанией Warner Bros. Interactive Entertainment.
The game is presented from the third-person perspective with a primary focus on Batman's combat and stealth abilities, detective skills, and gadgets that can be used in both combat and exploration.;Игра представлена от третьего лица с первичным акцентом на боевые и стелс-способности Бэтмена, детективные навыки и гаджеты, которые могут быть использованы в бою и исследовании.
We can eat and we can drink.;Мы можем есть и мы можем пить.
My test is new.;Мой тест новый.
Punch and Judy is a traditional puppet show featuring Mr. Punch and his wife Judy.;Панч и Джуди — традиционный уличный кукольный театр, центральными персонажами которого являются Панч и его жена Джуди.
The Victorian era was the period of Queen Victoria's reign, from 1837 until 1901.;Викторианская эпоха — период правления королевы Виктории, длившийся с 1837 по 1901 год.
Let's go to the cinema!;Пойдем в кино!
I installed the application long time ago, I need to check according to the time.;Я приложение уже давно установил, надо будет проверить по времени.

Error:

Traceback (most recent call last):
  File "/app/finetune.py", line 106, in <module>
    trainer.train()
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1929, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 2699, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 2731, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 1335, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 1208, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 837, in forward
    layer_outputs = encoder_layer(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 405, in forward
    hidden_states = residual + hidden_states
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

I tried this solution, manually defining a custom device_map and passing the model through dispatch_model with it, but it didn't help: the same RuntimeError reproduces. The attempt looked roughly like the sketch below.
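This is only an approximation of what I tried; the module names follow the M2M100/NLLB layout, and the particular encoder/decoder split across the two GPUs is just the example I used, not a recommendation:

from accelerate import dispatch_model

# Illustrative split: shared embeddings, encoder and lm_head on GPU 0, decoder on GPU 1
# (module names as they appear in M2M100ForConditionalGeneration).
device_map = {
    "model.shared": 0,
    "model.encoder": 0,
    "model.decoder": 1,
    "lm_head": 0,
}
model = AutoModelForSeq2SeqLM.from_pretrained(modelPath)
model = dispatch_model(model, device_map=device_map)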

@ArthurZ @sgugger
This looks like this issue.
I tried the proposed encoder-hook fix, but it didn't help either; roughly what I applied is sketched below.
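Reconstructed from memory, so the exact hook arguments may differ from the ones proposed in that issue:

from accelerate.hooks import add_hook_to_module, AlignDevicesHook

# Approximate reconstruction of the suggested "encoder hook": keep the encoder's
# inputs and outputs on a single device so the residual additions inside it
# do not mix cuda:0 and cuda:1 tensors.
encoder_device = next(model.model.encoder.parameters()).device
add_hook_to_module(
    model.model.encoder,
    AlignDevicesHook(execution_device=encoder_device, io_same_device=True),
)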
On the other hand, that fix targets NLLB-MoE 54B, whereas I need something analogous for NLLB-200-1.3B. Is there an appropriate solution?
According to the information there, NLLB was still pending compatibility with model parallelism as of April 25. Is there any up-to-date info on NLLB now?