How to use AdaFactor on TPU?

I am trying to use Adafactor with get_linear_schedule_with_warmup for fine-tuning T5 on TPU. The training loss keeps fluctuating but never decreases, and the validation loss stays constant. The learning rate changes correctly; I can see that from the Comet graph. I am able to train the same model on GPU without any issues. I am using the default values from the HF docs:

import os
from functools import partial

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    Trainer,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)
from transformers.optimization import Adafactor

# Config, data_collator, compute_metrics, train_ds and valid_ds are defined elsewhere in my script.

model_config = AutoConfig.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", config=model_config)
model.train()

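# Wrap the model so it is loaded only once and moved to the TPU device inside each spawned process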
WRAPPED_MODEL = xmp.MpModelWrapper(model)

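# Adafactor with a fixed external learning rate (parameter scaling and relative-step updates disabled), as in the HF docs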
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)

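# Linear warmup followed by linear decay over the total number of training steps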
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                               num_training_steps=Config.total_train_steps,
                                               num_warmup_steps=Config.warmup_steps)

def _mp_fn_boiler_plate(index, Config=Config):
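    # Each spawned process trains on its own XLA device (one TPU core)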
    device = xm.xla_device()

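    # Move the shared, wrapped model onto this process's TPU core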
    model = WRAPPED_MODEL.to(device)

    print("Loading datasets... ", end="")

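    # Per-device batch sizes are applied per TPU core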
    training_args = TrainingArguments(
        output_dir=os.path.join(os.curdir, "results"),
        num_train_epochs=1,
        evaluation_strategy="steps",
        weight_decay=0.0,
        logging_dir=os.path.join(os.curdir, "log"),
        eval_steps=Config.eval_steps,
        logging_steps=Config.logging_steps,
        per_device_train_batch_size=Config.train_batch_size,
        per_device_eval_batch_size=Config.valid_batch_size,
    )

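    # Hand the externally created optimizer and scheduler to the Trainer instead of letting it build its own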
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics if Config.compute_metrics else None,
        data_collator=data_collator,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        optimizers=(optimizer, lr_scheduler),
    )
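    # The model has already been moved to the XLA device above, so ask the Trainer not to place it again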
    trainer.place_model_on_device = False
    trainer.train()

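# Fork one training process per TPU core so the pre-loaded model is shared rather than re-loaded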
_mp_fn = partial(_mp_fn_boiler_plate, Config=Config)
xmp.spawn(_mp_fn, start_method="fork")