I am currently trying to fine-tune mBERT for binary classification on a TPU. I am connected to the TPU VM through SSH and have set up the trainer and training arguments like this:
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DefaultDataCollator,
    Trainer,
    TrainingArguments,
)

def preprocess_text(df):
    length = 200
    tokenized_text = tokenizer(df['concat_text'], padding="max_length", max_length=length)
    tokenized_text["label"] = df['post']
    return tokenized_text

def prepare_dataset():
    # training_file / test_file are the paths to my CSV files
    train_data = load_dataset("csv", data_files=training_file).map(preprocess_text, batched=True)
    test_data = load_dataset("csv", data_files=test_file).map(preprocess_text, batched=True)
    return train_data, test_data
def get_training_args(pre_model):
    return TrainingArguments(
        output_dir="model_run",
        logging_dir="model_run_logs",
        learning_rate=9e-6,
        dataloader_num_workers=4,
        do_train=True,
        num_train_epochs=5,
        weight_decay=0.01,
        save_steps=5000,
        eval_steps=5000,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        push_to_hub=False,
        save_total_limit=4,
        run_name=pre_model,
        auto_find_batch_size=True,
        tpu_num_cores=8,
        tpu_metrics_debug=True,
    )
def run_training(pre_model):
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(pre_model)
    data_collator = DefaultDataCollator()
    tokenized_train, tokenized_test = prepare_dataset()
    # id2label / label2id are defined elsewhere in my script
    model = AutoModelForSequenceClassification.from_pretrained(
        pre_model, num_labels=2, id2label=id2label, label2id=label2id, trust_remote_code=True,
        problem_type="single_label_classification"
    ).to(device)
    trainer = Trainer(
        model=model,
        args=get_training_args(pre_model),
        train_dataset=tokenized_train["train"],
        eval_dataset=tokenized_test["train"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=pre_process_logits,
    )
    trainer.train()
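For context, compute_metrics and pre_process_logits are small helpers along these lines (a simplified sketch rather than my exact code; the accuracy metric here is just illustrative):

import torch

def pre_process_logits(logits, labels):
    # Reduce the full logits to predicted class ids before the Trainer gathers them,
    # which keeps memory usage down during evaluation.
    return torch.argmax(logits, dim=-1)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # predictions are already class ids because of pre_process_logits
    return {"accuracy": (predictions == labels).mean().item()}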
I set the TPU as the device with:

import os
import torch_xla.core.xla_model as xm

os.environ["PJRT_DEVICE"] = "TPU"
device = xm.xla_device()
I then run the training with:
curr_model = "bert-base-multilingual-cased"
pre_model = curr_model
run_training(curr_model)
Running this fine-tuning takes forever, so I seem to be doing something wrong. I read that, to avoid the TPU recompiling the computation graph at every step, you must set a fixed max_length in the tokenizer so every batch has the same shape. However, this does not seem to help: training still takes forever.
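My understanding of that advice is that every batch must end up with exactly the same tensor shapes, roughly like this (an illustrative sketch; the truncation=True is my assumption of what is also required, my actual preprocessing is as shown above):

def preprocess_text_fixed_shapes(df):
    # Pad AND truncate every example to exactly 200 tokens so that each batch
    # has an identical shape and XLA only has to compile the graph once.
    tokenized_text = tokenizer(
        df['concat_text'],
        padding="max_length",
        truncation=True,
        max_length=200,
    )
    tokenized_text["label"] = df['post']
    return tokenized_text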
What am I doing wrong? Any help would be much appreciated.