Missing helper functions when customizing Trainer

Hi, I'm glad to find that the behavior of “Trainer” can be customized by overriding its methods. However, I am running into a problem with the helper functions the original code depends on. For example:

def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    ...
    if is_sagemaker_mp_enabled():
        scaler = self.scaler if self.use_amp else None
        loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
        return loss_mb.reduce_mean().detach().to(self.args.device)
    ...
    loss = self.compute_loss(model, inputs)
    print(loss) # This is the only place I want to change
    return loss.detach()

This is part of the code of the method “training_step”, which I want to override. Suppose I just want to print the loss at each training step without changing any other code. But apparently I cannot import the function “is_sagemaker_mp_enabled()”, and thus I have to delete those lines. I don't think that is a good solution; is there a more elegant way?

Thanks for the help!

There is no reason you shouldn’t be able to import is_sagemaker_mp_enabled from its location (transformers.file_utils).
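
For reference, here is a minimal sketch of that approach (the subclass name is illustrative, and the signature matches the version of training_step quoted above; in more recent releases the helper has moved to transformers.utils). Rather than copying the whole method body, it reuses the stock implementation via super() and only adds the print:

    from typing import Any, Dict, Union

    import torch
    from torch import nn
    from transformers import Trainer
    from transformers.file_utils import is_sagemaker_mp_enabled


    class PrintLossTrainer(Trainer):
        def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
            # Reuse the stock implementation, including its SageMaker model
            # parallelism branch, instead of copying the whole body.
            loss = super().training_step(model, inputs)
            print(f"training loss: {loss} (sagemaker mp: {is_sagemaker_mp_enabled()})")
            return loss

Passing PrintLossTrainer in place of Trainer leaves everything else untouched, and the same pattern works for any of the other overridable methods.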

Thanks! Now I can add extra logic without changing the original code.