### Reproduction
Use the example script in the repo with its sample command, but launched with `accelerate`.
Latest commit: [4871c82](https://github.com/huggingface/trl/commit/4871c82b0cd1caae72522182f9171ea069481250)
```bash
accelerate launch examples/scripts/orpo.py \
    --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --max_steps 1000 \
    --learning_rate 8e-6 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="gpt2-aligned-orpo" \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --log_level detail
```
Training hangs at step 0.
I have tracked it down to these two lines: https://github.com/huggingface/trl/blob/4871c82b0cd1caae72522182f9171ea069481250/trl/trainer/orpo_trainer.py#L847-L850
These are very large tensors being gathered across processes (4 × 223 × 50257 ≈ 45M elements, roughly 90 MB each in bf16):
```
torch.Size([4, 223, 50257])
torch.Size([4, 223, 50257])
```
Proposed fix: call `.detach().mean()` on them before the gather.
Happy to open a PR once we decide whether to average them before gathering, or to sum them as `KTOTrainer` does: https://github.com/huggingface/trl/blob/4871c82b0cd1caae72522182f9171ea069481250/trl/trainer/kto_trainer.py#L1270-L1272
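For concreteness, here is a minimal, self-contained sketch of the proposed pattern: detach and reduce each rank's logits to a scalar before gathering, rather than gathering the full `[batch, seq_len, vocab]` tensors. The variable names and metric keys are illustrative stand-ins, not the exact code at the linked lines.
```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Illustrative stand-ins for the per-rank ORPO logits (shapes as reported above).
policy_chosen_logits = torch.randn(4, 223, 50257, dtype=torch.bfloat16)
policy_rejected_logits = torch.randn(4, 223, 50257, dtype=torch.bfloat16)

metrics = {}
prefix = ""

# Instead of gathering the full ~45M-element tensors (which is what stalls the
# collective), detach and reduce to a scalar locally, then gather only the scalars.
metrics[f"{prefix}logits/chosen"] = accelerator.gather_for_metrics(
    policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/rejected"] = accelerator.gather_for_metrics(
    policy_rejected_logits.detach().mean()
).mean()

print(metrics)
```
Averaging per-rank means is exact only if every rank sees the same batch size; if that isn't guaranteed, summing as in `KTOTrainer` and dividing by the global count would be the safer aggregation.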
---
Unrelated note: the two lines below should also be `.detach()`ed, as I noticed they still carry the autograd graph.
https://github.com/huggingface/trl/blob/4871c82b0cd1caae72522182f9171ea069481250/trl/trainer/orpo_trainer.py#L852-L853
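A toy illustration of the point (hypothetical variable and metric names; the real lines may differ): storing values that still carry an autograd graph in the `metrics` dict keeps that graph and its activations alive, while `.detach()` drops the reference.
```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the two ORPO-specific values stored at those lines.
x = torch.randn(4, requires_grad=True)
log_odds_ratio = F.logsigmoid(x).mean()
log_odds_chosen = x.mean()

metrics = {}
# Detach before stashing in the metrics dict so no computation graph is retained.
metrics["log_odds_ratio"] = log_odds_ratio.detach()
metrics["log_odds_chosen"] = log_odds_chosen.detach()

assert not metrics["log_odds_ratio"].requires_grad
assert not metrics["log_odds_chosen"].requires_grad
```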
Credit to `morphism` on Discord, who helped track down the root-cause PR and provided hints.
### System Info
- Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
- Python version: 3.11.11
- TRL version: 0.16.0.dev0+4871c82
- PyTorch version: 2.5.1+cu124
- CUDA device(s): NVIDIA A40, NVIDIA A40
- Transformers version: 4.49.0
- Accelerate version: 1.3.0
- Accelerate config: not found
- Datasets version: 3.2.0
- HF Hub version: 0.28.1
- bitsandbytes version: 0.45.2
- DeepSpeed version: 0.16.1
- Diffusers version: not installed
- Liger-Kernel version: 0.5.3
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: 0.14.0
- vLLM version: not installed
### Checklist
- [x] I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))
- [x] I have included my system information
- [x] Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))
- [x] Any code provided is properly formatted in code blocks (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))
- [x] Any traceback provided is complete