Does anyone know if it is possible to use TRL on TPU? I am particularly interested in reinforcement learning and GRPO, as there currently do not seem to be any JAX alternatives out there.
TRL (Transformer Reinforcement Learning) is built on PyTorch, which can be challenging to run efficiently on TPUs. JAX-based alternatives for GRPO are limited, so the most direct option is to try running TRL on TPU via PyTorch/XLA. Native JAX support would require a custom implementation.
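If you do go the custom JAX route, the core of GRPO is small enough to sketch. The snippet below is only a rough illustration, not TRL's implementation: it assumes sequence-level log-probabilities, leaves out the KL penalty against a reference policy, and the names `group_relative_advantages`, `grpo_clipped_loss`, `eps`, and `clip_eps` are made up for this example.

```python
import jax.numpy as jnp


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within each group of completions.

    rewards: array of shape (num_prompts, group_size), one row per prompt.
    Each reward is centered and scaled by its own group's mean and std.
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)


def grpo_clipped_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """PPO-style clipped surrogate on group-relative advantages.

    Assumption: logp_new and logp_old are sequence-level log-probabilities
    with the same shape as advantages. The KL term is omitted here.
    """
    ratio = jnp.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Negated because optimizers minimize and the surrogate is maximized.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Two prompts, four sampled completions each (toy rewards).
rewards = jnp.array([[1.0, 0.0, 0.0, 1.0],
                     [2.0, 2.0, 2.0, 2.0]])
advantages = group_relative_advantages(rewards)

# With identical old and new log-probs the probability ratio is exactly 1.
logp = jnp.zeros_like(rewards)
loss = grpo_clipped_loss(logp, logp, advantages)
```

Both functions are pure, so they can be wrapped in `jax.jit` and differentiated with `jax.grad` once `logp_new` comes from an actual policy model. Sampling the completions and scoring them is the part you would still have to build yourself.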