RuntimeError: Input, output and indices must be on the current device

My training fails on this line:

outputs = model(**{key: torch.squeeze(value.to(self.training_device), 1) for key, value in inputs.items()})

I am trying to train a LayoutLMv2 model with torch.nn.DataParallel(). I am using a single-node instance with 4 GPUs in AWS SageMaker. However, when the code reaches the line above, it fails with the error in the title.

I had a look online and the suggestions seem to be to move both the model and the inputs to the GPU, but I think I already do that. Also, I cannot simply call inputs = inputs.to(self.training_device), because inputs is a dictionary: dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'labels', 'image'])
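So instead I move every tensor in the dictionary individually, which is what the failing line does. A minimal sketch of that pattern outside the Trainer (with a hypothetical two-key batch instead of my real inputs):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# hypothetical example batch with the same dict-of-tensors structure as my inputs
inputs = {
    "input_ids": torch.randint(0, 100, (1, 512)),
    "attention_mask": torch.ones(1, 512, dtype=torch.long),
}

# move every tensor in the dict to the target device, key by key
inputs = {key: value.to(device) for key, value in inputs.items()}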

This is my Trainer class:


import torch
from transformers import Trainer


class LanguageModelTrainer(Trainer):
    def __init__(self, *args, class_weights=None, l1_coef=0, **kwargs):
        self.class_weights = class_weights
        self.l1_coef = l1_coef
        self.training_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        model = torch.nn.DataParallel(model)
        model.to(self.training_device)

        loss_fn = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        # this is the line that fails with the RuntimeError from the title
        outputs = model(**{key: torch.squeeze(value.to(self.training_device), 1)
                           for key, value in inputs.items()})
        loss = loss_fn(outputs.logits.view(-1, model.config.num_labels), inputs["labels"].view(-1))

        if self.l1_coef:
            l1_norm = sum(torch.linalg.norm(p.flatten(), 1) for n, p in model.named_parameters() if 'bias' not in n)
            loss += self.l1_coef * l1_norm
        
        return (loss, outputs) if return_outputs else loss

For the DataParallel part I just followed this tutorial: Optional: Data Parallelism — PyTorch Tutorials 1.13.1+cu117 documentation
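From that tutorial, my understanding of the basic pattern is roughly the following (a minimal sketch with a hypothetical nn.Linear stand-in, not my actual LayoutLMv2 model):

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2)            # hypothetical stand-in model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicate the model across all visible GPUs
model.to(device)

x = torch.randn(8, 10).to(device)   # inputs go to the default device (cuda:0)
out = model(x)                      # DataParallel scatters the batch across GPUs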

I would be thankful if someone could point out what the problem is and how to fix it.