Training on TPU does not work (extremely slow)

I am currently trying to fine-tune mBERT for binary classification on a TPU.

I am connected to the TPU VM through SSH. I have set up the trainer and the training arguments like this:


import os

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DefaultDataCollator,
    Trainer,
    TrainingArguments,
)
import torch_xla.core.xla_model as xm


def preprocess_text(df):
    # Tokenize the concatenated text, padding every example to 200 tokens
    length = 200
    tokenized_text = tokenizer(df['concat_text'], padding="max_length", max_length=length)
    tokenized_text["label"] = df['post']
    return tokenized_text


def prepare_dataset():
    # load_dataset("csv", ...) returns a DatasetDict with a single "train" split,
    # which is why both datasets are indexed with ["train"] further down
    train_data = load_dataset("csv", data_files=training_file) \
        .map(preprocess_text, batched=True)
    test_data = load_dataset("csv", data_files=test_file) \
        .map(preprocess_text, batched=True)

    return train_data, test_data

def get_training_args(pre_model):
    # Batch size is left to auto_find_batch_size
    return TrainingArguments(
        output_dir="model_run",
        logging_dir="model_run_logs",
        learning_rate=9e-6,
        dataloader_num_workers=4,
        do_train=True,
        num_train_epochs=5,
        weight_decay=0.01,
        save_steps=5000,
        eval_steps=5000,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        push_to_hub=False,
        save_total_limit=4,
        run_name=pre_model,
        auto_find_batch_size=True,
        tpu_num_cores=8,
        tpu_metrics_debug=True,
    )


def run_training(pre_model):
    # preprocess_text picks the tokenizer up via this module-level name
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(pre_model)
    data_collator = DefaultDataCollator()

    tokenized_train, tokenized_test = prepare_dataset()

    model = AutoModelForSequenceClassification.from_pretrained(
        pre_model, num_labels=2, id2label=id2label, label2id=label2id, trust_remote_code=True,
        problem_type="single_label_classification"
    ).to(device)

    trainer = Trainer(
        model=model,
        args=get_training_args(pre_model),
        train_dataset=tokenized_train["train"],
        eval_dataset=tokenized_test["train"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=pre_process_logits,
    )

    trainer.train()
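
For completeness, id2label, label2id, compute_metrics and pre_process_logits are nothing special; roughly (simplified, with placeholder label names) they look like this:

# Simplified sketch of the helpers referenced above
# (placeholder label names; my real labels differ).
import evaluate

id2label = {0: "LABEL_0", 1: "LABEL_1"}
label2id = {v: k for k, v in id2label.items()}

accuracy = evaluate.load("accuracy")

def pre_process_logits(logits, labels):
    # reduce logits to predicted class ids so evaluation keeps less in memory
    return logits.argmax(dim=-1)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return accuracy.compute(predictions=predictions, references=labels)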

I am setting the TPU as the device with:

os.environ["PJRT_DEVICE"]="TPU"
device = xm.xla_device()
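
For what it is worth, a quick check like this should confirm that torch_xla actually sees the TPU (I believe it does in my case):

# Quick sanity check that torch_xla sees the TPU
import torch_xla.core.xla_model as xm

print(xm.xla_device())                 # e.g. xla:0
print(xm.get_xla_supported_devices())  # list of available XLA devices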

I then start the run with:

curr_model = "bert-base-multilingual-cased"
pre_model = curr_model
run_training(curr_model)
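
I am not using any TPU-specific launcher, just `python train.py` on the VM. I am unsure whether the Trainer handles the 8 TPU cores by itself or whether I am supposed to spawn one process per core explicitly, roughly like this (based on the torch_xla examples, not what I am currently running):

# Rough sketch of an explicit multi-core launch that I am unsure about
# (not what I am currently running).
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # each spawned process would run the training on one TPU core
    run_training("bert-base-multilingual-cased")

if __name__ == "__main__":
    xmp.spawn(_mp_fn)  # under PJRT, spawns one process per TPU device by default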

Running this fine-tuning takes forever, so I seem to be doing something wrong. I read that to avoid the TPU recompiling on every step you must set a fixed max_length in the tokenizer. However, this does not seem to help: it still takes forever.
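
If I understand the fixed-shape requirement correctly, the tokenizer call presumably also needs truncation enabled so that no example ever exceeds 200 tokens; something like the sketch below is what I would expect, but I am not sure whether my current call without truncation=True is the actual problem:

# What I think a fully fixed-shape tokenization should look like
# (adds truncation=True on top of what I have above, so every example is exactly 200 tokens).
def preprocess_text_fixed(df):
    tokenized_text = tokenizer(
        df['concat_text'],
        padding="max_length",  # pad shorter examples up to 200 tokens
        truncation=True,       # cut longer examples down to 200 tokens
        max_length=200,
    )
    tokenized_text["label"] = df['post']
    return tokenized_text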

What am I doing wrong? Any help would be much appreciated.