The XLA environment is pretty hellish, so I’m not even sure if the following is correct…
What your symptoms strongly suggest (Trainium / XLA context)
On Trn1, training runs on XLA (via Neuron + torch_xla), and Optimum’s NeuronTrainer is not feature-equivalent to the GPU Trainer. Two consequences matter most in your case:
- Evaluation is intentionally disabled in NeuronTrainer (it throws). (GitHub)
- Checkpointing is a synchronization-sensitive operation on XLA, and certain save sequences can cause hangs or even tensor corruption unless steps are marked/flushed correctly. (awsdocs-neuron.readthedocs-hosted.com)
Your “training stops learning right after the first checkpoint” is exactly the kind of “save boundary” failure XLA users run into when execution diverges across ranks or when saving interacts badly with lazy execution.
1) Evaluation: why you’re blocked, and what “supported” looks like on Trn1
Why the error happens
NeuronTrainer raises: “Evaluation is not supported in NeuronTrainer.” (GitHub)
So yes: with NeuronTrainer, you cannot do the usual evaluation_strategy="epoch"/"steps", compute_metrics, eval_dataset, etc.
The confusing part is that NeuronTrainingArguments still exposes many eval-related knobs (e.g., do_eval, eval_strategy, eval_steps) in the docs, but the trainer itself blocks the eval path. (Hugging Face)
Practical workaround that matches how Optimum-Neuron is designed
Do evaluation out-of-band:
- Save checkpoints periodically during training
- Consolidate shards if needed (common with Neuron distributed)
- Evaluate in a separate process (on a GPU box/local, or as a separate job on AWS)
This is the standard operational pattern when the training runtime can’t cheaply/robustly run eval inside the same compiled execution.
If you must evaluate on the Trn1 machine during training
You typically switch to a non-NeuronTrainer approach (plain HF Trainer on torch-neuronx/torch-xla). That route supports do_eval, but you inherit stricter “XLA discipline” (static shapes, compilation behavior, etc.). (This is a different training stack than Optimum’s NeuronTrainer.)
2) “save_strategy='epoch' saves nothing”: the most likely causes
Likely cause A: you’re expecting “GPU-style checkpoints,” but you’re getting sharded output / rank-scoped writes
On Neuron distributed setups, checkpoints can be sharded and not look like a single pytorch_model.bin. Optimum-Neuron’s distributed training guide describes shard-based layouts and how ZeRO-1 behaves. (Hugging Face)
Also, in NeuronX Distributed, not every rank writes the same thing (e.g., some states are written on DP rank 0 depending on configuration). (awsdocs-neuron.readthedocs-hosted.com)
If you’re inspecting the directory from a different rank/context, it can look like “nothing saved.”
Likely cause B: epoch boundaries are not being reached in the way you think
If you set max_steps (or your dataloader is effectively step-driven), the trainer can behave as “step-based,” and epoch-end callbacks may not trigger as expected.
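A quick way to sanity-check this is to compute how many epoch boundaries a step-capped run actually reaches. This is an illustrative back-of-the-envelope helper, not Trainer internals; the parameter names are my own:

```python
import math

def epoch_saves_expected(dataset_len, per_device_batch, grad_accum, world_size, max_steps):
    """Rough illustration (not Trainer internals): how many epoch-end
    save events a step-capped run can actually reach."""
    steps_per_epoch = math.ceil(dataset_len / (per_device_batch * grad_accum * world_size))
    total_steps = max_steps if max_steps > 0 else steps_per_epoch
    return total_steps // steps_per_epoch  # epoch boundaries actually crossed

# 100k samples, per-device batch 8, grad accum 4, 32 ranks -> ~98 optimizer
# steps per epoch; with max_steps=50 the run ends before the first epoch
# boundary, so save_strategy="epoch" never fires at all.
```

If this returns 0 for your run configuration, "epoch saves nothing" is expected behavior, not a bug.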
Likely cause C: relative output_dir="./training" + multi-process launch
Relative paths are a common footgun when launching multi-process jobs. Even when it “works,” you can end up with output written somewhere other than where you’re looking. (This is not Neuron-specific, but it shows up more often on Trn1 runs because you usually launch distributed.)
What I would do immediately: use an absolute output directory on a mounted volume you know persists (e.g., EBS), and verify the directory contents from the primary process.
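A minimal sketch of that fix (the path here is illustrative): resolve the output directory to an absolute path once, before building the training arguments, so every spawned rank writes to the same location regardless of its working directory.

```python
import os

# Resolve the checkpoint directory to an absolute path on a persistent
# volume BEFORE constructing training args, so all ranks agree on it.
output_dir = os.path.abspath(os.path.expanduser("~/runs/trn1-ckpts"))
os.makedirs(output_dir, exist_ok=True)

# then, on the real run:
# args = NeuronTrainingArguments(output_dir=output_dir, ...)
```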
3) save_strategy='steps' + rank/device error + zero_1 + “grad_norm becomes 0”
This is the core of your case. Here are the two most plausible, high-signal explanations.
Explanation 1: XLA save boundary causes replica divergence or state corruption
Two related XLA pitfalls are documented:
- Saving without a proper "step mark" can corrupt tensors in certain xm.save() sequences due to parameter aliasing. AWS explicitly recommends calling xm.mark_step() before xm.save() to avoid this class of issue. (awsdocs-neuron.readthedocs-hosted.com)
- If only the “master” replica runs checkpointing code while others proceed differently, execution can diverge and hang (or behave incorrectly) because XLA execution is lazy/graph-based. The PyTorch/XLA issue explicitly illustrates the save order problem and why checkpointing must be coordinated. (GitHub)
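The "mark step, then save" discipline can be sketched as a tiny helper. The XLA functions are injected as parameters so the ordering is explicit and testable off-device; on Trn1 you would pass `xm.mark_step` and `xm.save` from `torch_xla.core.xla_model`:

```python
def checkpoint_with_mark_step(state_dict, path, *, mark_step, save_fn):
    """Flush pending lazy XLA execution before serializing, per the AWS
    guidance that xm.save() without a preceding xm.mark_step() can hit
    parameter-aliasing corruption.

    On Trn1 (sketch):
        import torch_xla.core.xla_model as xm
        checkpoint_with_mark_step(model.state_dict(), ckpt_path,
                                  mark_step=xm.mark_step, save_fn=xm.save)
    """
    mark_step()                # materialize live tensors first
    save_fn(state_dict, path)  # only then serialize
```

The point is the ordering, not the helper itself: if a trainer's save path does these two calls in the wrong order (or skips the mark), you get exactly the class of corruption AWS warns about.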
Why this matches your symptom: you see normal learning until the first checkpoint save, then grad_norm collapses to 0 and training stops improving. That is a classic “something went wrong at the checkpoint boundary” signal on XLA.
Explanation 2: you’re interpreting grad_norm from a rank/stage where it becomes meaningless after the first save
Optimum-Neuron’s trainer tracks grad_norm (and in some distributed/pipeline setups, what gets logged per process can be misleading). (GitHub)
This is less likely given your statement that the model “stops learning,” but it’s worth verifying by looking at loss and/or evaluating the saved checkpoint externally.
Where zero_1 fits in (and why toggling it changes the failure mode)
zero_1 is ZeRO-1 optimizer-state sharding. Optimum’s own docs describe ZeRO-1 and when it’s beneficial. (Hugging Face)
Checkpointing with ZeRO-1 is a known sharp edge in Neuron stacks historically; NeuronX Distributed release notes explicitly mention “Fixed an issue with Zero1 checkpoint saving/loading” and also note that checkpointing is sharded and needs combining. (awsdocs-neuron.readthedocs-hosted.com)
So, a very practical interpretation of your experience is:
- With zero_1=True, you hit a bug / incompatibility in the checkpointing path (the rank/device empty error).
- With zero_1=False, you avoid that path, but you still hit an XLA save-boundary problem that knocks training off course after the first save (because saving is still happening, just via a different route).
This is consistent with the fact that Optimum-Neuron has ongoing trainer refactors and training-related fixes across releases. (GitHub)
What I would do for your run (a concrete plan)
Step 0: Make the run observable
Even if you can’t evaluate in-trainer, you can still avoid “blind training”:
- Save checkpoints periodically
- Evaluate those checkpoints externally (GPU/local)
- Track a simple scalar like training loss + learning rate over time
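For the scalar tracking, even a CSV appender is enough to spot the "flat after first checkpoint" pattern. A minimal sketch (the function and file layout are my own, not an Optimum-Neuron API); you would call it from wherever you read the trainer's logged metrics:

```python
import csv
import os

def append_scalars(csv_path, step, loss, lr):
    """Append one row of training scalars, writing a header on first use.
    Plotting loss/lr over step makes a checkpoint-boundary collapse obvious."""
    new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(["step", "loss", "lr"])
        writer.writerow([step, loss, lr])
```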
Step 1: Switch to the “least fragile” checkpointing mode first
Use step-based saving, and initially simplify what you save:
- save_strategy="steps"
- save_steps set to something moderate (e.g., every few hundred optimizer steps)
- save_only_model=True (Optimum-Neuron exposes this) (Hugging Face)
This removes optimizer/scheduler/RNG state from the checkpoint, which reduces the ZeRO-1 and XLA serialization surface area. It’s ideal if your goal is “evaluate during training,” not “resume training exactly from checkpoint.”
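Collected as keyword arguments, step 1 looks roughly like this. The values (path, cadence) are illustrative; on the real run you would pass them to NeuronTrainingArguments:

```python
# Conservative checkpointing settings for step 1 (values illustrative).
# On the real run: NeuronTrainingArguments(**checkpoint_kwargs, ...)
checkpoint_kwargs = dict(
    output_dir="/mnt/ebs/trn1-run",  # absolute path on a persistent volume
    save_strategy="steps",
    save_steps=500,                  # moderate cadence: every few hundred optimizer steps
    save_only_model=True,            # drop optimizer/scheduler/RNG state from checkpoints
)
```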
Step 2: Keep XLA serialization settings conservative
Optimum-Neuron documents defaults like use_xser=True and async_save=False in the trainer args. (Hugging Face)
Stick with:
- use_xser=True
- async_save=False
because async saving increases complexity/memory pressure and can make save boundaries harder to reason about.
Step 3: Address the XLA save boundary hazard
Because AWS has a documented “xm.save() sequence can corrupt tensors” note, I would treat this as a version-sensitive issue and do one of:
- Upgrade/downgrade to a Neuron stack version where the xm.save issue is mitigated (or where Optimum has incorporated the safe sequence), guided by your AMI's supported versions.
- Ensure the code path doing the save includes the “mark step before save” discipline that AWS recommends. (awsdocs-neuron.readthedocs-hosted.com)
You may not control Optimum’s internal save calls directly, so in practice this often means: move to a newer Optimum-Neuron + torch-neuronx/torch-xla combo (or a known-good pinned set).
Step 4: Re-introduce zero_1=True once model-only saving is stable
Given that NxD release notes mention a Zero1 checkpoint fix, I would:
- Pin to a Neuron stack version that includes that fix, then re-enable zero_1=True while keeping save_only_model=True
If that's stable, and only then, consider saving optimizer state (if you truly need resume).
Step 5: External evaluation loop (replaces in-training eval)
Since NeuronTrainer can’t evaluate, you do:
- Train with checkpoint saves
- After each checkpoint: run a separate evaluation script that loads the checkpoint and runs metrics
- Log results (even a CSV is enough)
This gives you “eval every epoch/steps” operationally, without relying on unsupported trainer paths. (GitHub)
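A small sketch of the eval-driver side: discover `checkpoint-<step>` directories (the standard HF Trainer naming) and process them in training order. Numeric sorting matters because lexicographic sorting would put checkpoint-1000 before checkpoint-200. The driver loop at the bottom is a placeholder for your own metric script:

```python
import os
import re

def checkpoints_in_order(output_dir):
    """Return 'checkpoint-<step>' subdirectories sorted by step number,
    so an external eval job processes them in training order."""
    pat = re.compile(r"^checkpoint-(\d+)$")
    found = []
    for name in os.listdir(output_dir):
        m = pat.match(name)
        if m:
            found.append((int(m.group(1)), os.path.join(output_dir, name)))
    return [path for _, path in sorted(found)]

# Eval driver (sketch):
# for path in checkpoints_in_order("/mnt/ebs/trn1-run"):
#     load the model from `path` on a GPU/CPU box, run metrics,
#     append a row to your results CSV
```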
“Why does the final model save work, but intermediate saves break training?”
This is common on XLA systems: end-of-training saving typically happens after a controlled barrier/teardown, while mid-loop saving happens inside the compiled/lazy execution flow. If that mid-loop save causes:
- divergence across replicas, or
- a bad save sequence that corrupts live tensors,
you can see exactly your pattern: first checkpoint triggers the failure, everything after is flat.
AWS’s own warning about xm.save + the XLA checkpointing discussion are directly relevant to this “mid-loop save boundary” problem. (awsdocs-neuron.readthedocs-hosted.com)
Minimal “target state” for your situation
If your priority is: stable training + periodic evaluation, the most robust approach on Trn1 today is usually:
- Train with NeuronTrainer (no eval)
- Save model-only checkpoints by steps
- Evaluate externally
- Only attempt full “resume-able” checkpoints (optimizer/scheduler) once the above is stable, and only with a version set where ZeRO-1 checkpointing is known-good (awsdocs-neuron.readthedocs-hosted.com)
If you want the fastest pinpoint fix
The two most diagnostic pieces of info (no back-and-forth needed) are:
- versions: optimum-neuron, neuronx-distributed, torch-neuronx, torch-xla
- the exact checkpoint-save log lines around the first save (plus whether you launched with torchrun and how many ranks)
With that, the likely resolution becomes either: