BetterTransformer with HF Trainer

I want to use BetterTransformer to optimize my training loop. However, when I use the vanilla Trainer, the checkpointing step makes training fail with:

ValueError: ('You are trying to save or push a model that has been converted with BetterTransformer.', ' Please revert the model to its original state before calling save_pretrained or push_to_hub.', ' By calling model = BetterTransformer.reverse(model) before saving or pushing.')

However, there doesn't seem to be a way to access the model once it has been passed to the Trainer at initialization, since callbacks appear to only modify the TrainerControl instance, which doesn't store the model.

Is there any way around this?

I'm not an expert on this, but an alternative could be to edit the Trainer source code (trainer.py) where the checkpoint is saved. You can try adding model = BetterTransformer.reverse(model) before this line, and then convert it back into a BetterTransformer model after saving with model = BetterTransformer.transform(model).
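A minimal sketch of the revert, save, convert-back sequence from the previous reply, written as a small helper so it can be called at the save point rather than pasted inline. The name save_via_round_trip and its parameters are hypothetical; the three callables are passed in so the sketch runs without optimum installed.

```python
from typing import Callable, TypeVar

M = TypeVar("M")  # whatever model type the caller uses


def save_via_round_trip(
    model: M,
    reverse: Callable[[M], M],
    transform: Callable[[M], M],
    save: Callable[[M], None],
) -> M:
    """Revert `model`, save the reverted version, then convert it back.

    Hypothetical helper mirroring the manual steps:
        model = reverse(model); <save>; model = transform(model)
    The caller must keep using the *returned* model afterwards.
    """
    reverted = reverse(model)   # canonical form that saving accepts
    save(reverted)              # e.g. reverted.save_pretrained(path)
    return transform(reverted)  # back to the optimized form for training
```

With optimum, reverse and transform would be BetterTransformer.reverse and BetterTransformer.transform, and save something like lambda m: m.save_pretrained(output_dir). Because the Trainer's optimizer keeps references to the parameters it was created with, it is worth confirming that the converted-back model still uses those same parameters before relying on this in the middle of training.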

I worked around it with a callback. This one handles saving at the end of each epoch, but you could probably write a step-based variant quite easily.

import os
from transformers import (
    TrainingArguments,
    TrainerControl,
    TrainerState,
    TrainerCallback,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from optimum.bettertransformer import BetterTransformer


class SaveBetterTransformerModelCallback(TrainerCallback):
    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if args.save_strategy == IntervalStrategy.EPOCH:
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            # Revert to the canonical transformers implementation so save_pretrained() accepts it
            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)
            control.should_save = False  # Disable saving in trainer loop
        return control
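The callback works because of how the control flags are handed around. Below is a toy, stdlib-only model of that hand-off, assuming (as I understand the Trainer to behave) that the built-in flow callback runs before user callbacks and that the Trainer only writes its own checkpoint if should_save is still set after all callbacks for that event have run. It also sketches the step-based variant: with a "steps" save strategy the built-in callback has already set should_save at the right steps, so the custom callback only needs to react to it. Control, default_flow_on_step_end, SaveRevertedOnStep and toy_training_loop are hypothetical stand-ins, not transformers classes.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Control:
    """Stand-in for TrainerControl (only the flag used here)."""
    should_save: bool = False


def default_flow_on_step_end(step: int, save_steps: int, control: Control) -> Control:
    # Stand-in for the built-in flow callback: request a save every `save_steps` steps.
    if save_steps > 0 and step % save_steps == 0:
        control.should_save = True
    return control


@dataclass
class SaveRevertedOnStep:
    """Stand-in for a step-based variant of the epoch callback."""
    saved_at: List[int] = field(default_factory=list)

    def on_step_end(self, step: int, control: Control) -> Control:
        if control.should_save:
            # Real version: revert kwargs["model"] and call save_pretrained on it.
            self.saved_at.append(step)
            control.should_save = False  # stop the trainer from saving the converted model
        return control


def toy_training_loop(max_steps: int, save_steps: int, callback: SaveRevertedOnStep) -> List[int]:
    """Return the steps at which the trainer itself would have saved."""
    trainer_saves: List[int] = []
    control = Control()
    for step in range(1, max_steps + 1):
        # Built-in callback first, user callback after it.
        control = default_flow_on_step_end(step, save_steps, control)
        control = callback.on_step_end(step, control)
        if control.should_save:  # where the trainer's own checkpointing would run
            trainer_saves.append(step)
            control.should_save = False
    return trainer_saves


cb = SaveRevertedOnStep()
print(cb.saved_at, toy_training_loop(10, 4, cb), cb.saved_at)
```

In this toy run the callback records saves at steps 4 and 8 while the trainer's own save list stays empty, which is the behaviour the real callback relies on to avoid the ValueError.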

My remaining issue is that I can't get it to work with EarlyStoppingCallback, and I'm not sure why yet.