You’re right, your transform is being called for every sample until the last checkpoint step.
This happens in the Trainer code here, when it skips over the batches that were already seen before the checkpoint.
Maybe at that point in the code you can remove the transform and then re-add it?
To remove the formatting transform, you can use `dataset.reset_format()`, and then you can set the transform again.
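Here is a minimal sketch of that remove-then-re-add idea. It assumes your dataset is a `datasets.Dataset` whose transform was attached with `dataset.set_transform(my_transform)`. The helper name `transform_disabled` is hypothetical and is not part of `datasets` or `transformers`. Wrapping it in a context manager means the transform gets re-attached even if the skipping code raises. Keep in mind that `reset_format()` clears any other format as well (for example a `set_format("torch")`), and this helper restores only the transform.

```python
from contextlib import contextmanager


@contextmanager
def transform_disabled(dataset, transform):
    """Temporarily drop an on-the-fly transform from a dataset.

    Assumes `dataset` behaves like a Hugging Face `datasets.Dataset` whose
    transform was set with `dataset.set_transform(transform)`.
    Hypothetical helper, not a library API.
    """
    dataset.reset_format()  # rows are now returned without the transform
    try:
        yield dataset
    finally:
        # Re-attach the transform, even if the code inside the block raised.
        dataset.set_transform(transform)


# Usage sketch (assumes `train_dataset` is a datasets.Dataset and
# `my_transform` is the function you passed to set_transform):
#
# with transform_disabled(train_dataset, my_transform):
#     ...  # skip the already-seen batches here
```

The `with` block should wrap only the batch-skipping part, so the batches that are actually trained on still go through the transform. This also assumes nothing in the skipping path needs the transformed columns, such as a data collator that only understands transformed rows.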