I’m training a GPT2LMHeadModel using the Trainer class (both from the Transformers library). I want to generate text from a prompt after every n steps of training. How would I do this?
Thanks!
Hi and welcome to the forum!
You can do this with a custom callback. The `Trainer` class in the Transformers library accepts `TrainerCallback` subclasses, which it calls at fixed points in the training loop. Here’s a simple callback that generates text every `n` steps.
I haven’t tested the code below, but it should give you enough guidance to get this working.
```python
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Custom callback: generate a sample from a fixed prompt every n_steps optimizer steps
class GenerateTextCallback(TrainerCallback):
    def __init__(self, tokenizer: GPT2Tokenizer, prompt: str, n_steps: int = 100, max_length: int = 50):
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.n_steps = n_steps
        self.max_length = max_length  # counts prompt tokens + generated tokens

    def on_step_end(self, args: TrainingArguments, state, control, **kwargs):
        # The Trainer maintains state.global_step, so no manual counter is needed
        # (this also keeps the schedule correct when resuming from a checkpoint).
        if state.global_step % self.n_steps != 0 or not state.is_world_process_zero:
            return
        model: GPT2LMHeadModel = kwargs["model"]
        # Put the prompt on whatever device the model currently lives on
        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(model.device)
        model.eval()  # disable dropout while sampling
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_length=self.max_length,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
        model.train()  # restore training mode before the next step
        decoded_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\nGenerated text after {state.global_step} steps: {decoded_text}\n")

# Instantiating the callback
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
callback = GenerateTextCallback(tokenizer, "Once upon a time", n_steps=100)

# Setting up the Trainer
training_args = TrainingArguments(
    output_dir="./results",
    # other training arguments...
)
trainer = Trainer(
    model=model,  # Assuming 'model' is defined elsewhere in your code
    args=training_args,
    train_dataset=train_dataset,  # Assuming 'train_dataset' is defined elsewhere in your code
    eval_dataset=eval_dataset,  # Assuming 'eval_dataset' is defined elsewhere in your code
    data_collator=data_collator,  # Assuming 'data_collator' is defined elsewhere in your code
    callbacks=[callback],
)
trainer.train()
```
Start training with the `Trainer` and watch the console output. Every `n` steps, you should see the text generated by the model from your prompt.
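Note that "every `n` steps" means every `n` optimizer updates: the Trainer calls `on_step_end` once per update (which is also when `state.global_step` is incremented), so with gradient accumulation the callback fires every `n` updates, not every `n` batches. The pure-Python simulation below is a sketch of that pattern; `trigger_points` is a hypothetical helper, not part of Transformers, and it assumes one update every `grad_accum_steps` batches while ignoring any partial accumulation at the end of an epoch.

```python
from typing import List, Tuple

def trigger_points(num_batches: int, grad_accum_steps: int, n_steps: int) -> List[Tuple[int, int]]:
    """Return (batch_index, global_step) pairs at which text would be generated.

    Hypothetical simulation: assumes one optimizer update every
    grad_accum_steps batches and ignores a trailing partial accumulation.
    """
    points = []
    global_step = 0
    for batch_idx in range(1, num_batches + 1):
        if batch_idx % grad_accum_steps == 0:  # accumulation boundary -> optimizer update
            global_step += 1                    # mirrors state.global_step
            if global_step % n_steps == 0:      # same modulo check the callback uses
                points.append((batch_idx, global_step))
    return points

print(trigger_points(num_batches=800, grad_accum_steps=4, n_steps=100))
# [(400, 100), (800, 200)]
```

So with `gradient_accumulation_steps=4`, a sample appears every 400 batches rather than every 100.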
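Make sure the prompt length plus the number of generated tokens stays within the model’s maximum sequence length (1,024 tokens for GPT-2); otherwise, you might encounter errors. Note that `max_length` counts the prompt tokens as well as the newly generated ones.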
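If your prompt might be long, one option is to clip it before generating so that prompt plus new tokens fit in the context window. The helper below, `fit_prompt`, is hypothetical (not a Transformers function): it keeps the most recent tokens of a list of token IDs, and it assumes you then call `model.generate(..., max_new_tokens=max_new_tokens)` so the budget is expressed in new tokens rather than total length.

```python
from typing import List

def fit_prompt(prompt_ids: List[int], max_new_tokens: int, context_size: int = 1024) -> List[int]:
    """Left-truncate prompt_ids so len(prompt) + max_new_tokens <= context_size."""
    budget = context_size - max_new_tokens  # room left for the prompt
    if budget <= 0:
        raise ValueError("max_new_tokens must be smaller than context_size")
    # Keep the last `budget` tokens; a shorter prompt is returned unchanged.
    return prompt_ids[-budget:]

clipped = fit_prompt(list(range(1000)), max_new_tokens=50)
print(len(clipped), clipped[0])  # 974 26
```

Keeping the end of the prompt (rather than the start) is a design choice: for continuation-style generation the most recent context usually matters most.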
You may also want to manage verbosity: if `n` is small, generating text that often will clutter your console (and slow down training).
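One way to keep the console readable is to write each sample to a file instead of printing it. The sketch below uses two hypothetical helpers, `append_sample` and `read_samples`, that store samples as JSON lines; inside the callback you would call `append_sample(...)` in place of `print(...)`. The demo strings are placeholders, not real model output.

```python
import json
import os
import tempfile

def append_sample(path: str, step: int, prompt: str, text: str) -> None:
    """Append one generated sample as a single JSON line."""
    record = {"step": step, "prompt": prompt, "text": text}
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def read_samples(path: str) -> list:
    """Load all logged samples back, e.g. to compare outputs across steps."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Demo with placeholder text written to a temporary directory
demo_path = os.path.join(tempfile.mkdtemp(), "samples.jsonl")
append_sample(demo_path, 100, "Once upon a time", "Once upon a time <placeholder 1>")
append_sample(demo_path, 200, "Once upon a time", "Once upon a time <placeholder 2>")
samples = read_samples(demo_path)
print([s["step"] for s in samples])  # [100, 200]
```

A JSON-lines file also makes it easy to diff how the model's output evolves over training.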
Thank you! Something like this was what I was looking for. I will need to make a few changes, but now I get the idea. I’ll let you know if I have any issues.