I want to use BetterTransformer to optimize my training loop. However, when I use the vanilla Trainer, the checkpointing operation causes training to error out with:

ValueError: ('You are trying to save or push a model that has been converted with BetterTransformer.', ' Please revert the model to its original state before calling save_pretrained or push_to_hub.', ' By calling model = BetterTransformer.reverse(model) before saving or pushing.')
However, there doesn't seem to be a way to touch the model once it's passed into the Trainer (at initialization), as callbacks only touch the TrainerControl instance, which doesn't store the model…
Is there any way around this?
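For context, here is a minimal sketch of the setup that triggers this; the model name is arbitrary and train_dataset stands in for any tokenized dataset:

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model = BetterTransformer.transform(model)  # swap in the fused-kernel implementation

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", save_strategy="epoch"),
    train_dataset=train_dataset,  # placeholder: any tokenized dataset
)
trainer.train()  # raises the ValueError above at the first checkpoint save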
I'm not an expert on this, but an alternative could be to edit the Trainer source code. You can try to add model = BetterTransformer.reverse(model) before the line where the Trainer saves the model, then convert it back into a BetterTransformer model after saving with model = BetterTransformer.transform(model).
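Rather than patching the installed source, a less invasive variant of the same idea could be to subclass Trainer and wrap its save step. A minimal sketch (the save_model override is my own suggestion, not from the thread, and its signature may differ between transformers versions):

from optimum.bettertransformer import BetterTransformer
from transformers import Trainer

class BetterTransformerTrainer(Trainer):
    def save_model(self, output_dir=None, _internal_call=False):
        # Revert to the original architecture so save_pretrained doesn't raise
        self.model = BetterTransformer.reverse(self.model)
        super().save_model(output_dir, _internal_call=_internal_call)
        # Re-apply BetterTransformer so training continues on the fast path
        self.model = BetterTransformer.transform(self.model)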
I worked around it with a callback. This specific one saves on epochs, but you can probably add one based on steps quite easily.
import os

from optimum.bettertransformer import BetterTransformer  # missing from the original snippet
from transformers import (
    TrainingArguments,
    TrainerControl,
    TrainerState,
    TrainerCallback,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


class SaveBetterTransformerModelCallback(TrainerCallback):
    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if args.save_strategy == IntervalStrategy.EPOCH:
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )
            # Revert to the original architecture before saving, since
            # save_pretrained refuses to serialize a transformed model
            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)
            control.should_save = False  # Disable saving in the trainer loop
        return control
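For completeness, wiring it up looks something like this (model and dataset construction elided; output_dir and the dataset are placeholders, and save_strategy="epoch" matches the IntervalStrategy.EPOCH check inside the callback):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    save_strategy="epoch",  # matches the check in on_epoch_end above
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,  # a model already converted with BetterTransformer.transform
    args=args,
    train_dataset=train_dataset,  # placeholder: any tokenized dataset
    callbacks=[SaveBetterTransformerModelCallback()],
)
trainer.train()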
My issue is that I can't make it work with EarlyStoppingCallback for some reason; not sure why yet. (Possibly because EarlyStoppingCallback depends on load_best_model_at_end to track best_metric via the Trainer's own checkpoints, and this callback suppresses those built-in saves with control.should_save = False.)