Accessing model from a callback to predict between epochs

Hi there!

I’m writing a custom callback to do active learning, based on the paper [2107.14153] Semi-Supervised Active Learning with Temporal Output Discrepancy.

To do so, at the end of each epoch during training, we need to run predictions on a selected set of candidates. We then add the candidates whose predictions are the most inconsistent to the training dataset.
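To make the selection step concrete, here is a minimal sketch. It assumes the inconsistency is measured as the L2 distance between a candidate's outputs at two consecutive epochs, which is my reading of the paper's temporal output discrepancy. The function names and the toy numbers are hypothetical, and plain Python lists stand in for the real logits.

```python
import math

def temporal_discrepancy(prev_outputs, curr_outputs):
    """L2 distance between each candidate's outputs at two consecutive epochs."""
    return [math.dist(p, c) for p, c in zip(prev_outputs, curr_outputs)]

def select_most_inconsistent(prev_outputs, curr_outputs, k):
    """Indices of the k candidates whose outputs changed the most."""
    scores = temporal_discrepancy(prev_outputs, curr_outputs)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]

# Toy outputs for three candidates (two classes), made up for illustration.
epoch_1 = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]
epoch_2 = [[0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]

# The second candidate (index 1) changed the most between the two epochs.
selected = select_most_inconsistent(epoch_1, epoch_2, k=1)
```

The callback is where I collect those per-epoch outputs: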

import torch
from transformers import TrainerCallback

class ActiveLearningCallback(TrainerCallback):

    def on_epoch_end(self, args, state, control, **kwargs):
        # we use the model to predict the labels for each candidate image
        # (self.inputs holds the preprocessed candidate batch, set elsewhere)
        model = kwargs["model"]
        with torch.no_grad():
            logits = model(**self.inputs).logits
        # we store the predictions at the end of each epoch

Is the model accessed through kwargs["model"] the current state of the model at the end of each epoch? The predictions I get from it don’t correspond at all to those of the model uploaded to the Hub after training.

I saw it used this way in another callback: transformers/ at 04ab5605fbb4ef207b10bf2772d88c53fc242e83 · huggingface/transformers · GitHub

Thank you!