Using GRPOTrainer with a custom PyTorch module?

Hello! I was wondering what changes I would need to make to use GRPOTrainer with a custom PyTorch module class.

Currently I have an nn.Module subclass that wraps around an existing Hugging Face transformer, but with custom forward and generate functions.

I was wondering, though, whether there are resources on converting an nn.Module into a transformers model that can be used with Trainer, and what other functionality I would need to implement, as well as what changes I'd need to make to my forward and generate methods to get them working with GRPOTrainer.


If you inherit from PreTrainedModel, you should get most of the necessary functionality. As for Trainer, it seems you can override the loss computation, gradient-related steps, and so on.
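Something along these lines might work as a starting point (a minimal sketch only; MyWrapperConfig, MyWrapperModel, and the "gpt2" backbone are placeholder assumptions, not your actual setup):

# Sketch of wrapping a custom module as a PreTrainedModel so that
# Trainer / GRPOTrainer can call forward() and generate() on it.
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin

class MyWrapperConfig(PretrainedConfig):
    model_type = "my_wrapper"

    def __init__(self, base_model_name="gpt2", **kwargs):
        self.base_model_name = base_model_name
        super().__init__(**kwargs)

class MyWrapperModel(PreTrainedModel, GenerationMixin):
    # Recent transformers versions want GenerationMixin inherited explicitly
    # for generate(); on older versions PreTrainedModel already provides it.
    config_class = MyWrapperConfig

    def __init__(self, config):
        super().__init__(config)
        # Wrap the existing Hugging Face transformer; custom layers go alongside it.
        self.base = AutoModelForCausalLM.from_pretrained(config.base_model_name)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Your custom logic goes here; returning a ModelOutput with logits
        # (and loss when labels are passed) keeps Trainer happy.
        return self.base(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Delegate so the inherited generate() loop can build per-step inputs.
        return self.base.prepare_inputs_for_generation(input_ids, **kwargs)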

Sorry for bumping, but do you know if Trainer supports passing inputs_embeds to generate yet? I ended up monkey-patching the generate method to use my own generation code, but when I pass inputs_embeds instead of input_ids into the original generate method, I get the following error:

Traceback (most recent call last):
  File "test_grpo.py", line 124, in <module>
    trainer.train()
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 3692, in training_step
    inputs = self._prepare_inputs(inputs)
  File "env/lib64/python3.9/site-packages/trl/trainer/grpo_trainer.py", line 576, in _prepare_inputs
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
IndexError: argmax(): Expected reduction dim 1 to have non-zero size.

Fixed. It turns out I just needed to pad the generated output back with the original prompt.
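For anyone hitting the same IndexError: it seems to come from the completion slice being empty, since TRL's _prepare_inputs slices the completion out of the generated sequence after the prompt length, and generate called with only inputs_embeds returns just the newly generated tokens. Roughly what the fix looks like (a sketch under those assumptions; generate_with_prompt is just an illustrative helper):

import torch

def generate_with_prompt(model, input_ids, attention_mask, **gen_kwargs):
    # Custom generation path via inputs_embeds; the output typically contains
    # only the new tokens, not the prompt.
    inputs_embeds = model.get_input_embeddings()(input_ids)
    new_tokens = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        **gen_kwargs,
    )
    # Re-attach the original prompt so the trainer's prompt/completion slicing
    # sees a non-empty completion.
    return torch.cat([input_ids, new_tokens], dim=1)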
