How to show the learning rate during training

Hi everyone :slight_smile:

I would like to know if it is possible to include the learning rate value as part of the information presented during the training.

The columns Accuracy, F1, Precision and Recall were added after setting a custom compute_metrics function. I would like to have the learning rate there as well.

Is it possible to add it there?

Thanks in advance :slight_smile:

Hi Milyiyo,

I am curious: how did you add accuracy, F1, precision and recall?

Hi @Amalq,

I used this:

from sklearn.metrics import precision_recall_fscore_support, accuracy_score


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

That function was provided to the Trainer like this:

from transformers import Trainer

trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  tokenizer=tokenizer,
                  eval_dataset=dataset['validation'],
                  train_dataset=dataset['train'],
                  data_collator=data_collator)

I hope this can help :slight_smile:

Thank you so much @milyiyo , that is really helpful.

Hi Alberto, yes, it is possible to include the learning rate in the evaluation logs!

Fortunately, the log() method of the Trainer class is one of the methods you can override in a subclass to inject custom behaviour (see the Trainer documentation).

So, all you have to do is create your own Trainer subclass and override the log() method like so:

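from typing import Dict

from transformers import Trainer


class MyTrainer(Trainer):
    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        # Add the current learning rate to every log entry (training and evaluation)
        logs["learning_rate"] = self._get_learning_rate()
        # Forward any extra arguments that newer transformers versions may pass to log()
        super().log(logs, *args, **kwargs)


trainer = MyTrainer(...)  # same arguments as for Trainer
trainer.train()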

You should now see the learning rate in the eval logs.
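To see the override pattern in isolation (inject a key, then delegate to the parent's log()), here is a minimal sketch that runs without transformers. BaseTrainer is a hypothetical stand-in that only records what it is asked to log; it is not the real Trainer.

```python
from typing import Dict, List


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer: it only records logs."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate
        self.log_history: List[Dict[str, float]] = []

    def _get_learning_rate(self) -> float:
        return self.learning_rate

    def log(self, logs: Dict[str, float]) -> None:
        # Store a copy of whatever the caller logged
        self.log_history.append(dict(logs))


class MyTrainer(BaseTrainer):
    def log(self, logs: Dict[str, float]) -> None:
        # Inject the extra key, then let the parent class do the actual logging
        logs["learning_rate"] = self._get_learning_rate()
        super().log(logs)


trainer = MyTrainer(learning_rate=5e-5)
trainer.log({"eval_loss": 0.42})
print(trainer.log_history)  # [{'eval_loss': 0.42, 'learning_rate': 5e-05}]
```

Because the subclass calls super().log(), everything the parent does with the logs (printing, callbacks, history) still happens; the override only enriches the dict first.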

Hope that helps, let me know if you have any questions.

Cheers
Heiko

Thanks @marshmellow77 :slight_smile:

That solved my problem; the learning rate now shows up as a new column.

Thanks a lot for your help :hugs:

Hi @marshmellow77,
I am trying to use the same logic to log the learning rate at each epoch, but I am getting the following error:
AttributeError: 'LMCL_Trainer' object has no attribute '_get_learning_rate'

LMCL_Trainer is the name of my subclass of Trainer. Here is the class:

from typing import Dict

from transformers import Trainer


class LMCL_Trainer(Trainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def log(self, logs: Dict[str, float]) -> None:
        logs["learning_rate"] = self._get_learning_rate()
        super().log(logs)

I have transformers version 3.3.0 installed.

Hi @Vighnesh, the transformers library is currently at version 4.19. A lot has changed since v3.3, so I recommend upgrading to the latest version and checking whether this solution works for you.

Cheers
Heiko

Thanks a lot. Yes, the latest version has a lot of changes and improvements.

The solution worked with transformers version 4.19.
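If upgrading is not an option, one possible fallback (an assumption, not something tested in this thread against v3.3.0) is to read the rate straight from the optimizer: every PyTorch optimizer keeps the current value under param_groups[i]["lr"], and Trainer stores its optimizer as trainer.optimizer once training has started. The helper current_learning_rate below is hypothetical, and the SimpleNamespace objects are stand-ins so the sketch runs without torch or transformers.

```python
from types import SimpleNamespace
from typing import Any


def current_learning_rate(trainer: Any) -> float:
    """Hypothetical helper: prefer Trainer._get_learning_rate when it exists,
    otherwise fall back to the optimizer's first parameter group."""
    if hasattr(trainer, "_get_learning_rate"):
        return trainer._get_learning_rate()
    # PyTorch optimizers store the current rate in each param group's "lr" key.
    # Assumption: trainer.optimizer has already been created (i.e. training started).
    return trainer.optimizer.param_groups[0]["lr"]


# Stand-in for an older Trainer that has an optimizer but no _get_learning_rate
old_trainer = SimpleNamespace(optimizer=SimpleNamespace(param_groups=[{"lr": 3e-05}]))
print(current_learning_rate(old_trainer))  # 3e-05
```

Inside a log() override you would then write logs["learning_rate"] = current_learning_rate(self). Only the first parameter group is read, which is enough when all groups share one schedule.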

Hi :slightly_smiling_face:

I’ve tried the suggested solution with transformers version 4.31.0, but for some reason the learning rate does not show up in the evaluation logs.

Any ideas why? (I do see that it is saved in the log history.)
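Since the value does end up in the log history, one way to inspect it is to read it back from trainer.state.log_history, which holds one dict per logging call. A minimal sketch, assuming each entry carries a "step" key alongside whatever was logged; the helper name learning_rates and the sample history values are made up for illustration.

```python
from typing import Any, Dict, List, Tuple


def learning_rates(log_history: List[Dict[str, Any]]) -> List[Tuple[int, float]]:
    """Return (step, learning_rate) pairs from a Trainer-style log history."""
    return [
        (entry["step"], entry["learning_rate"])
        for entry in log_history
        if "learning_rate" in entry  # skip entries that did not log a rate
    ]


# Made-up history in the same shape as trainer.state.log_history
history = [
    {"loss": 0.69, "learning_rate": 4e-05, "epoch": 1.0, "step": 500},
    {"eval_loss": 0.55, "eval_accuracy": 0.81, "epoch": 1.0, "step": 500},
    {"loss": 0.41, "learning_rate": 2e-05, "epoch": 2.0, "step": 1000},
]
print(learning_rates(history))  # [(500, 4e-05), (1000, 2e-05)]
```

With a real trainer you would call learning_rates(trainer.state.log_history) after trainer.train() and print or plot the pairs yourself.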

@marshmellow77 Hello. I want to use your solution to log the learning rate. I am currently using Python 3.10, and I get the following error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[36], line 1
----> 1 class MyTrainer(Trainer):
      2     def log(self, logs: Dict[str, float]) -> None:
      3         logs["learning_rate"] = self._get_learning_rate()

Cell In[36], line 2, in MyTrainer()
      1 class MyTrainer(Trainer):
----> 2     def log(self, logs: Dict[str, float]) -> None:
      3         logs["learning_rate"] = self._get_learning_rate()
      4         print(logs["learning_rate"])

NameError: name 'Dict' is not defined

Which module do I need to import it from?

Hi @artyomboyko,

Did you solve the error?

Maybe you were missing an import like: from typing import Dict
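Alternatively, on Python 3.9 and later (PEP 585) the built-in dict can be subscripted in annotations, so the import is not needed at all. A small illustration of both spellings; the function names are hypothetical.

```python
from typing import Dict


# Python 3.5+: typing.Dict works, but needs the import above
def log_with_typing(logs: Dict[str, float]) -> None:
    print(logs)


# Python 3.9+ (PEP 585): the built-in dict is subscriptable, no import needed
def log_with_builtin(logs: dict[str, float]) -> None:
    print(logs)


log_with_typing({"learning_rate": 5e-05})  # {'learning_rate': 5e-05}
log_with_builtin({"learning_rate": 5e-05})  # {'learning_rate': 5e-05}
```

In the Trainer subclass this means writing def log(self, logs: dict[str, float], ...) on Python 3.10.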

@milyiyo Hello and Happy New Year! Yes, you can find the solution here.