Training Loop Freezing on Google Colab

I was using a copy of the script from this doc page, adapted to run on a TPU in Google Colab. It runs fine at first, but once it reaches the main training loop it runs for a while and then freezes completely, no longer responding to interrupts, leaving me no option but to restart the session.

The part that confuses me is that it always seems to freeze on the 619th pass through the loop, whether the batch size is 1 or 10.

I'm hoping this is a mistake I've made, but I have no idea how to confirm that and can't find anyone with a similar problem.

CODE:

import datasets
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding

# tokenizer, tokenize_func, baseModel and device (the TPU device) are set up
# earlier in the script, same as in the doc page version; omitted here.

# Load the semicolon-separated CSV and tokenize it.
mainDataset = datasets.load_dataset("csv", data_files=["steamCleaned_train.csv"], sep=";")

tokenizedDataset = mainDataset.map(tokenize_func, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenizedDataset = tokenizedDataset.remove_columns(["sentence"])
tokenizedDataset.set_format("torch")

train_dataloader = DataLoader(
    tokenizedDataset["train"], shuffle=True, batch_size=2, collate_fn=data_collator
)
model = AutoModelForSequenceClassification.from_pretrained(baseModel)

# Sanity check: run a single batch through the model before training.
for batch in train_dataloader:
    break
print({k: v.shape for k, v in batch.items()})
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)
print("DONE PREPROCESSING")

optimizer = AdamW(model.parameters(), lr=5e-5)


from transformers import get_scheduler

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


model.to(device)


progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

print("DONE")

Managed to figure out the issue in the end: it was not a problem with the code but with my dataset. It's a production dataset, but it was missing examples for one of the labels, which caused the weights/biases to quickly converge to ~1e-22 (at around 600 iterations).
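
A cheap way to catch this kind of collapse before the session locks up is to log the loss and parameter magnitudes every few hundred steps. Rough sketch only (the check_collapse helper, the 1e-20 threshold and the 100-step interval are placeholders, not part of my original script):

import torch

def check_collapse(model, step, loss, threshold=1e-20):
    # Warn if the loss is non-finite or a parameter tensor has collapsed
    # towards zero (largest absolute value below `threshold`).
    if not torch.isfinite(loss):
        print(f"step {step}: non-finite loss {loss.item()}")
    for name, param in model.named_parameters():
        max_abs = param.detach().abs().max().item()
        if max_abs < threshold:
            print(f"step {step}: {name} has collapsed (max |w| = {max_abs:.3e})")

# Called inside the training loop (assuming you keep a running `step` counter):
#     if step % 100 == 0:
#         check_collapse(model, step, loss)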

From what I understand, TPUs have limited floating-point precision, and it was this that caused the freezing. I switched to an updated dataset and it worked perfectly.
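
In hindsight, a quick label-count check on the dataset would have caught this before training. Something along these lines, assuming the label column is called "label" and the labels are integer class ids (adjust for your own column names):

from collections import Counter

# Count how many training examples each label has.
label_counts = Counter(mainDataset["train"]["label"])
print(label_counts)

# Flag any class the model expects but the data never provides.
num_labels = model.config.num_labels
missing = [i for i in range(num_labels) if label_counts.get(i, 0) == 0]
if missing:
    print(f"Warning: no training examples for label(s) {missing}")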