Hello,
I am confused about a design choice in transformers/examples/flax/text-classification/run_flax_glue.py (at commit 5757923888246ea16b324f53c60ea444574005ed of huggingface/transformers on GitHub), copied below:
# define step functions
def train_step(
    state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, float]:
    """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    targets = batch.pop("labels")

    def loss_fn(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_fn(logits, targets)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng
My question is about loss_fn using the targets variable from the enclosing scope of train_step directly, while taking params as an explicit argument. Is there a specific optimization, JAX requirement, or recommendation against also passing targets as an explicit parameter, or is it an unrelated design decision, or maybe something I'm not seeing?

(I guess one reason may be that we are not interested in the gradients with respect to targets, so we do not want jax.value_and_grad to compute those. But jax.value_and_grad also has an argument to specify which inputs to differentiate with respect to, so technically this does not seem necessary; maybe the closure is just the clearer design?)
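
For concreteness, here is a minimal, self-contained sketch of the two variants I have in mind (with a toy loss and made-up data, not the actual GLUE code). As far as I can tell, both differentiate only with respect to params, since argnums defaults to 0:

import jax
import jax.numpy as jnp

def toy_loss(params, targets):
    # Stand-in for the real model + loss; only params should be differentiated.
    preds = params * 2.0
    return jnp.mean((preds - targets) ** 2)

def variant_closure(params, targets):
    # Variant A (as in run_flax_glue.py): targets is captured from the enclosing
    # scope, so loss_fn takes only params and value_and_grad differentiates w.r.t. it.
    def loss_fn(params):
        return toy_loss(params, targets)

    return jax.value_and_grad(loss_fn)(params)

def variant_explicit(params, targets):
    # Variant B: targets is passed explicitly; argnums=0 (the default) still restricts
    # differentiation to the first argument, so no gradient w.r.t. targets is computed.
    return jax.value_and_grad(toy_loss, argnums=0)(params, targets)

params = jnp.array(1.0)
targets = jnp.array([3.0, 4.0])
print(variant_closure(params, targets))   # same (loss, grad) pair as the line below
print(variant_explicit(params, targets))

So both variants seem to compute the same thing; my guess is the closure version is just meant to keep the differentiated function's signature limited to the parameters, but I would like to confirm that.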
Thank you very much in advance!
Hande