Hi, I'm glad to see that the behavior of “Trainer” can be customized by overriding its methods. However, I'm running into a problem with the helper functions that the existing methods rely on. For example:
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        ...
        if is_sagemaker_mp_enabled():
            scaler = self.scaler if self.use_amp else None
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        ...
        loss = self.compute_loss(model, inputs)
        print(loss)  # This is the only place I want to change
        return loss.detach()
This is part of the “training_step” method, which I want to override. Suppose I only want to print the loss at each training step without changing any of the other code. But apparently I cannot import “is_sagemaker_mp_enabled()”, so I have to delete the lines that use it. That doesn't seem like a good solution. Is there a more elegant way?
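One common way to avoid copying the method body at all is to call the parent implementation through `super()` and add only the print around it. The sketch below illustrates that pattern. `BaseTrainer` is a hypothetical stand-in for `transformers.Trainer` so the example runs without installing the library, and `LoggingTrainer` is an illustrative name. In real code you would subclass `Trainer` directly. Forwarding `*args, **kwargs` is a design choice that keeps the override working even if the parent's `training_step` signature gains extra parameters.

```python
from typing import Any, Dict


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer, used only so this
    sketch runs without installing transformers."""

    def compute_loss(self, model, inputs: Dict[str, Any]) -> float:
        # Placeholder loss: the real Trainer runs a forward pass here.
        return float(sum(inputs["values"]))

    def training_step(self, model, inputs: Dict[str, Any]) -> float:
        # Placeholder for the parent's full logic (SageMaker MP, AMP,
        # backward pass, ...), which the subclass never has to copy.
        return self.compute_loss(model, inputs)


class LoggingTrainer(BaseTrainer):
    """Override that only adds a print and delegates everything else."""

    def training_step(self, *args, **kwargs):
        # Reuse the parent's implementation unchanged, so helpers such as
        # is_sagemaker_mp_enabled() never need to be imported here.
        loss = super().training_step(*args, **kwargs)
        print(f"training loss: {loss}")
        return loss


trainer = LoggingTrainer()
result = trainer.training_step(None, {"values": [1.0, 2.0]})  # prints "training loss: 3.0"
```

One caveat with the real `Trainer`: the value that `training_step` returns may already be averaged across GPUs or divided by `gradient_accumulation_steps`, so it can differ from the raw `compute_loss` output. If you need the raw value, overriding `compute_loss` in the same `super()` style is an alternative.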
Thanks for the help!