Trainer API Training Not Happening

Hi everyone. I’m trying to train a model on a sequence tagging task, but it doesn’t seem to be learning. I have the following variables:

tokenized_ds = DatasetDict({                                                                                                                           
    train: Dataset({                                                                                                                                         
            features: ['id', 'tokens', 'tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
            num_rows: 87963                                                                                                                                  
        })                                                                                                                                                   
    val: Dataset({                                                                                                                                           
            features: ['id', 'tokens', 'tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],                                                   
            num_rows: 29321                                                                                                                                  
        })                                                                                                                                                   
    test: Dataset({                                                                                                                                          
            features: ['id', 'tokens', 'tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],                                                   
            num_rows: 29322                                                                                                                                  
        })                                                                                                                                                   
    })

Here’s how the train split’s features look:

train_ds.features = {'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'tags': Sequence(feature=ClassLabel(names=['O', 'PERIOD', 'COMMA', 'QUESTION_MARK'], id=None), length=-1, id=None)}   
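My preprocessing is essentially the standard token-classification alignment: tokenize with is_split_into_words=True, give the first subword of each word its tag, and set everything else to -100. Roughly like this (raw_ds and the checkpoint name are placeholders for what I actually use):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # placeholder checkpoint

def tokenize_and_align_labels(examples):
    tokenized = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    all_labels = []
    for i, tags in enumerate(examples["tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        labels, previous_word_id = [], None
        for word_id in word_ids:
            if word_id is None:
                labels.append(-100)           # special tokens ([CLS], [SEP], padding)
            elif word_id != previous_word_id:
                labels.append(tags[word_id])  # first subword of a word gets its tag
            else:
                labels.append(-100)           # continuation subwords are ignored by the loss
            previous_word_id = word_id
        all_labels.append(labels)
    tokenized["labels"] = all_labels
    return tokenized

tokenized_ds = raw_ds.map(tokenize_and_align_labels, batched=True)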

Here’s my compute_metrics function:

import numpy as np
import torch
from torcheval.metrics.functional import (
    multiclass_accuracy,
    multiclass_f1_score,
    multiclass_precision,
    multiclass_recall,
)

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Keep only positions with a real label (drop the -100 special/continuation tokens).
    true_labels = [
        [l for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)]
    true_predictions = [
        [p for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)]

    # Debug output to inspect what the model is predicting.
    print(true_labels)
    print(true_predictions)

    # Flatten to 1-D tensors for torcheval.
    true_labels_flat_torch = torch.tensor([item for sub_l in true_labels for item in sub_l])
    true_predictions_flat_torch = torch.tensor([item for sub_l in true_predictions for item in sub_l])

    # torcheval expects (input=predictions, target=labels).
    f1 = multiclass_f1_score(true_predictions_flat_torch, true_labels_flat_torch, num_classes=len(label_list), average="macro")
    precision = multiclass_precision(true_predictions_flat_torch, true_labels_flat_torch, num_classes=len(label_list), average="macro")
    recall = multiclass_recall(true_predictions_flat_torch, true_labels_flat_torch, num_classes=len(label_list), average="macro")
    accuracy = multiclass_accuracy(true_predictions_flat_torch, true_labels_flat_torch, num_classes=len(label_list), average="macro")

    return {
        "precision": precision.item(),
        "recall": recall.item(),
        "f1": f1.item(),
        "accuracy": accuracy.item()}

Lastly my training arguments:

import timeit

from torch.optim import AdamW
from transformers import Trainer, TrainingArguments, get_linear_schedule_with_warmup

EPOCH = 3
BATCH_SIZE = 64
# Total optimizer steps: epochs * batches per epoch (not epochs * examples).
TRAIN_STEPS = EPOCH * (len(tokenized_ds["train"]) // BATCH_SIZE)
training_args = TrainingArguments(                                                                                                                           
            output_dir="./model_save/auto",                                                                                                                  
            learning_rate=2e-5,                                                                                                                              
            per_device_train_batch_size=64,                                                                                                                  
            per_device_eval_batch_size=64,                                                                                                                   
            num_train_epochs=3,                                                                                                                              
            weight_decay=0.01,                                                                                                                               
            evaluation_strategy="epoch",                                                                                                                     
            save_strategy="epoch",                                                                                                                           
            disable_tqdm=False,                                                                                                                              
            log_level="error",                                                                                                                               
            load_best_model_at_end=True)                                                                                                                     
                                                                                                                                                             
optimizer = AdamW(model.parameters(),                                                                                                                        
                  lr = 1e-3,                                                                                                                                 
                  betas = (0.9, 0.999),                                                                                                                      
                  eps = 1e-6)                                                                                                                                
                                                                                                                                                             
scheduler = get_linear_schedule_with_warmup(optimizer,                                                                                                       
                                            num_training_steps=TRAIN_STEPS,                                                                                  
                                            num_warmup_steps=0)                                                                                              
                                                                         
trainer = Trainer(                                                                                                                                           
            model=model,                                                                                                                                     
            args=training_args,                                                                                                                              
            train_dataset=tokenized_ds["train"],
            eval_dataset=tokenized_ds["val"],
            tokenizer=tokenizer,                                                                                                                             
            data_collator=data_collator,                                                                                                                     
            compute_metrics=compute_metrics,)                                                                                                                
            #optimizers=(optimizer, scheduler))                                                                                                              
                                                                                                                                                             
start = timeit.default_timer()                                                                                                                               
trainer.train()                                                                                                                                              
stop = timeit.default_timer()                                                                                                                                
print(f"Training Time: {stop-start:.2f}s")      

However, the model always predicts label 0 ('O') for every token, and the F1 score doesn’t change at any epoch. The model simply isn’t learning, and I don’t understand what’s wrong.
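In case it’s relevant: since the model collapses to class 0 ('O'), I’m also checking how skewed the tag distribution is in the training split (quick sketch):

from collections import Counter

# Count how often each tag occurs across the whole training split.
tag_counts = Counter(
    label_list[t] for tags in tokenized_ds["train"]["tags"] for t in tags
)
print(tag_counts)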