About training Flax transformers: why use the targets variable from the enclosing scope but pass params as an argument to loss_fn?


I am confused about a design choice in transformers/examples/flax/text-classification/run_flax_glue.py in huggingface/transformers on GitHub (commit 5757923888246ea16b324f53c60ea444574005ed), copied below:

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning `(new_state, metrics, new_dropout_rng)`."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

My question is about loss_fn reading the targets variable directly from the enclosing scope of train_step while receiving params as an explicit argument. Is there a specific optimization, JAX requirement, or recommendation against also passing targets as an explicit parameter? Or is it an unrelated design decision, or maybe something I'm not seeing?
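To make the comparison concrete, here is a minimal sketch of the two styles. It assumes a toy linear model with softmax cross-entropy in place of the Flax transformer and state.loss_fn; the names model_apply, cross_entropy, grads_closure_style and grads_explicit_style are hypothetical. Since jax.value_and_grad differentiates only with respect to its first positional argument by default (argnums=0), passing targets explicitly should not make JAX compute gradients for them either, and both variants should return the same loss and gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model standing in for the Flax transformer (assumption).
def model_apply(params, x):
    return x @ params["w"] + params["b"]

# Hypothetical softmax cross-entropy standing in for state.loss_fn (assumption).
def cross_entropy(logits, labels):
    log_probs = jax.nn.log_softmax(logits)
    # Select the log-probability of the true class for each example.
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=1)
    return -jnp.mean(picked)

def grads_closure_style(params, x, targets):
    # Same pattern as run_flax_glue.py: targets is captured from the enclosing scope.
    def loss_fn(p):
        return cross_entropy(model_apply(p, x), targets)
    return jax.value_and_grad(loss_fn)(params)

def grads_explicit_style(params, x, targets):
    # Alternative: targets is passed explicitly. With the default argnums=0,
    # only p is differentiated, so no gradient is computed for t.
    def loss_fn(p, t):
        return cross_entropy(model_apply(p, x), t)
    return jax.value_and_grad(loss_fn)(params, targets)

# Usage: random toy data, 8 examples, 4 features, 3 classes.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(k1, (4, 3)), "b": jnp.zeros(3)}
x = jax.random.normal(k2, (8, 4))
targets = jnp.array([0, 1, 2, 0, 1, 2, 0, 1])

loss_a, g_a = grads_closure_style(params, x, targets)
loss_b, g_b = grads_explicit_style(params, x, targets)
```

In both cases the returned gradient has the same structure as params (a dict with "w" and "b"), so for this toy setup the choice looks like a matter of style rather than of what gets differentiated.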

(I guess one reason may be that we are not interested in the gradients w.r.t. targets, so we do not want jax.value_and_grad to compute them. But jax.value_and_grad has an argnums argument that specifies which positional arguments to differentiate with respect to, and by default it is only the first one. So technically it does not seem necessary to handle it this way, but maybe it is a somewhat clearer design?)
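A minimal sketch of how argnums behaves, assuming a hypothetical squared-error loss with float targets (the loss and variable names are illustrative, not from run_flax_glue.py). It also shows that, assuming integer class labels as is typical for classification, asking JAX for the gradient with respect to the labels raises a TypeError, so such labels can only ever be a non-differentiated input.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, targets):
    # Hypothetical toy loss; targets are float here, so they *can* be differentiated.
    preds = params["w"] * 2.0
    return jnp.mean((preds - targets) ** 2)

params = {"w": jnp.array([1.0, 2.0])}
targets = jnp.array([0.5, 1.5])

# Default argnums=0: gradient w.r.t. params only.
loss, g_params = jax.value_and_grad(loss_fn)(params, targets)

# argnums=(0, 1): gradients w.r.t. both params and targets, returned as a tuple.
loss2, (g_params2, g_targets) = jax.value_and_grad(loss_fn, argnums=(0, 1))(params, targets)

def try_int_grad():
    # Differentiating w.r.t. integer labels is rejected by JAX's input dtype check.
    int_labels = jnp.array([0, 1])
    try:
        jax.value_and_grad(loss_fn, argnums=1)(params, int_labels)
    except TypeError:
        return "TypeError"
    return "no error"
```

Here loss is 4.25, the gradient w.r.t. params["w"] is [3.0, 5.0], and the extra gradient w.r.t. targets is [-1.5, -2.5]; with the default argnums that last computation is simply not requested.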

Thank you very much in advance!