Generating text while model is still training

I’m training a GPT2LMHeadModel model using the Trainer class (both from the Transformers library). I want to generate text from a prompt every n steps during training. How would I do this?

Thanks!

Hi and welcome to the forum!

The Trainer class in the Transformers library supports custom callbacks. Here’s a simple callback that generates text every n steps.
I haven’t tested the code below, but it should give you enough guidance to get it working.

import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Custom Callback Definition
class GenerateTextCallback(TrainerCallback):
    """Generates a sample from a fixed prompt every `n_steps` optimizer steps."""

    def __init__(self, tokenizer: GPT2Tokenizer, prompt: str, n_steps: int = 100, max_new_tokens: int = 50):
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.n_steps = n_steps
        self.max_new_tokens = max_new_tokens

    def on_step_end(self, args: TrainingArguments, state, control, **kwargs):
        # state.global_step is the Trainer's own optimizer-step counter; it is also
        # restored when resuming from a checkpoint, unlike a counter kept on the callback
        if state.global_step % self.n_steps != 0:
            return
        model: GPT2LMHeadModel = kwargs["model"]
        was_training = model.training
        model.eval()  # disable dropout while sampling
        # Encode the prompt (with attention mask) on the same device as the model
        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():  # no gradients needed for generation
            output_ids = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
        decoded_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\nGenerated text after {state.global_step} steps: {decoded_text}\n")
        if was_training:
            model.train()  # switch back to training mode before the next step

# Instantiating Callback
# Use the same checkpoint the model was loaded from so the vocabularies match
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
callback = GenerateTextCallback(tokenizer, "Once upon a time", n_steps=100)

# Setting up the Trainer class
training_args = TrainingArguments(
    output_dir="./results",
    # other training arguments...
)

trainer = Trainer(
    model=model,   # Assuming 'model' is defined elsewhere in your code
    args=training_args,
    train_dataset=train_dataset,   # Assuming 'train_dataset' is defined elsewhere in your code
    eval_dataset=eval_dataset,     # Assuming 'eval_dataset' is defined elsewhere in your code
    data_collator=data_collator,   # Assuming 'data_collator' is defined elsewhere in your code
    callbacks=[callback],
)

trainer.train()

Start training using the Trainer class and observe the console output. After every n steps, you should see the generated text from the model.
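
A counter kept on the callback (like a `self.step_count` incremented in `on_step_end`) and the Trainer's own `state.global_step` agree on a fresh run, but they drift apart when training resumes from a checkpoint, because a new callback object starts counting from 0 while `global_step` is restored. The sketch below simulates this in plain Python without Transformers. `SimulatedState` and `simulate_fired_steps` are hypothetical names, and the loop assumes, as the Trainer does, that `global_step` is incremented just before `on_step_end` is called.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SimulatedState:
    """Hypothetical stand-in for transformers.TrainerState (only global_step)."""
    global_step: int = 0


def simulate_fired_steps(start_step: int, end_step: int, n_steps: int, use_global_step: bool) -> List[int]:
    """Return the global steps at which a 'generate every n steps' hook fires.

    Mimics a loop that increments global_step and then calls on_step_end,
    starting from start_step (e.g. a checkpoint being resumed).
    """
    state = SimulatedState(global_step=start_step)
    local_count = 0  # a freshly created callback's own counter starts at 0
    fired = []
    while state.global_step < end_step:
        state.global_step += 1  # an optimizer step has just finished
        local_count += 1
        counter = state.global_step if use_global_step else local_count
        if counter % n_steps == 0:
            fired.append(state.global_step)
    return fired


# Fresh run: both strategies agree.
print(simulate_fired_steps(0, 300, 100, use_global_step=True))   # [100, 200, 300]
print(simulate_fired_steps(0, 300, 100, use_global_step=False))  # [100, 200, 300]

# Resumed from a checkpoint saved at step 250: only global_step stays aligned.
print(simulate_fired_steps(250, 500, 100, use_global_step=True))   # [300, 400, 500]
print(simulate_fired_steps(250, 500, 100, use_global_step=False))  # [350, 450]
```

Keying the check off `state.global_step` means samples always appear at the same step numbers, which makes outputs from different runs easier to compare.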

Make sure the prompt’s length plus the number of generated tokens doesn’t exceed the model’s maximum sequence length (1024 tokens for GPT-2, available as model.config.n_positions). Otherwise, you may run into errors.
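
One way to respect that limit is to compute the generation budget from the prompt length before calling `generate`. The helper below is a hypothetical, library-free sketch: it assumes a limit of 1024 positions (GPT-2’s `n_positions`; in real code you would read `model.config.n_positions`) and returns a value you could pass as `max_new_tokens`, clipped so that prompt plus output fits.

```python
def generation_budget(prompt_len: int, desired_new_tokens: int, max_positions: int = 1024) -> int:
    """Hypothetical helper: how many new tokens can be generated without
    exceeding the model's maximum sequence length.

    prompt_len: number of tokens in the encoded prompt
    desired_new_tokens: how many tokens you would like to generate
    max_positions: model limit (assumed 1024, as for GPT-2's n_positions)
    """
    if prompt_len >= max_positions:
        raise ValueError(
            f"Prompt has {prompt_len} tokens; it must be shorter than {max_positions}."
        )
    remaining = max_positions - prompt_len
    return min(desired_new_tokens, remaining)


print(generation_budget(prompt_len=4, desired_new_tokens=50))     # 50
print(generation_budget(prompt_len=1000, desired_new_tokens=50))  # 24
```

For a prompt encoded with `return_tensors="pt"`, `prompt_len` would be the last dimension of the `input_ids` tensor (`input_ids.shape[-1]`).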
You may also want to manage the verbosity: if n is small, generating text every n steps will clutter your console.
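
If console clutter becomes a problem, one option is to append each sample to a file and read it later. The class below is a hypothetical, library-free sketch (`SampleLogger` is not part of Transformers): it writes one JSON object per line (JSONL) with the step number, prompt and generated text.

```python
import json
from pathlib import Path


class SampleLogger:
    """Hypothetical helper that appends generated samples to a JSONL file
    instead of printing them to the console."""

    def __init__(self, path: str):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

    def log(self, step: int, prompt: str, text: str) -> None:
        record = {"step": step, "prompt": prompt, "text": text}
        # Append mode so samples from the whole run accumulate in one file
        with self.path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    def read_all(self) -> list:
        """Load every logged sample, e.g. to compare outputs across steps."""
        if not self.path.exists():
            return []
        with self.path.open(encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]
```

In the callback you could then replace the `print(...)` call with something like `self.logger.log(state.global_step, self.prompt, decoded_text)`, where `self.logger` is a hypothetical attribute set in `__init__`.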


Thank you! Something like this was what I was looking for. I will need to make a few changes, but now I get the idea. I’ll let you know if I have any issues.
