DeepSpeed Trainer and custom loss weights

Hi all, I wrote a custom loss as suggested in this forum:

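import torch
from transformers import Trainer

# class_weights_pt is a per-class weight tensor created elsewhere on a fixed device.
# It is not moved along with the model, so it can end up on a different device than the logits.
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights_pt)

class SentTrainer(Trainer):
    # **kwargs absorbs extra arguments newer Trainer versions pass (e.g. num_items_in_batch)
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute the class-weighted custom loss
        labels = inputs.get("labels")
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss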

However, after integrating DeepSpeed, it complained that the weight tensor was not on the same device as the logits. How should I fix this?

I think you can define the loss function inside the __init__ method, moving the class weights to the appropriate device there, and then call self.loss_fct in compute_loss instead of the module-level loss_fct:

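import torch
from typing import Optional
from transformers import Trainer

class SentTrainer(Trainer):
    def __init__(self, *args, class_weights: Optional[torch.FloatTensor] = None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:  # move the weights to the accelerator's device
            class_weights = class_weights.to(self.accelerator.device)
        self.loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)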
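An alternative to fixing the device once in __init__ is to move the weights at call time, so they always follow whatever device the logits are on. The sketch below is plain PyTorch and is not taken from the answer above; the helper name weighted_cross_entropy is hypothetical, and also casting the weights to the logits' dtype is an extra assumption meant to avoid a possible dtype mismatch when the model runs in half precision.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits: torch.Tensor,
                           labels: torch.Tensor,
                           class_weights: torch.Tensor) -> torch.Tensor:
    """Class-weighted cross-entropy whose weights follow the logits.

    Hypothetical helper: the weights are moved on every call, so it does not
    matter on which device (or in which dtype) they were originally created.
    """
    num_labels = logits.size(-1)
    # Align the weights with the logits' device and dtype (e.g. cuda:N, fp16)
    weight = class_weights.to(device=logits.device, dtype=logits.dtype)
    # Flatten any leading dimensions, as in the original compute_loss
    return F.cross_entropy(
        logits.reshape(-1, num_labels), labels.reshape(-1), weight=weight
    )
```

Inside compute_loss you would then write loss = weighted_cross_entropy(logits, labels, class_weights_pt) instead of calling the module-level loss_fct. The cost is one small tensor copy per step, which is usually negligible next to the forward pass.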