Custom class for token classification

I am trying to define my own token classification model class in PyTorch
so that it can be used in the same way as the RobertaForTokenClassification
class from transformers, and so that the trained model can be saved with save_pretrained
and reloaded with from_pretrained.

I thought it would be easiest to slightly modify the code of RobertaForTokenClassification from transformers.
The definitions of MyConfig and the MyModelForTokenClassification class are as follows:

import torch
from transformers import RobertaConfig, RobertaModel, RobertaPreTrainedModel

class MyConfig(RobertaConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.additional_parameter = …

class MyModelForTokenClassification(RobertaPreTrainedModel):
    config_class = MyConfig

    def __init__(self, config, additional_data=None):
        super().__init__(config)
        self.additional_data = additional_data
        self.additional_parameter = config.additional_parameter
        self.num_labels = config.num_labels
        # RobertaForTokenClassification names this attribute `roberta`,
        # which matches RobertaPreTrainedModel.base_model_prefix
        self.model = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        hidden_size = config.hidden_size
        if self.additional_data is not None:
            hidden_size = .....
        self.classifier = torch.nn.Linear(hidden_size, config.num_labels)
        # Initialize weights and apply final processing, as RobertaForTokenClassification does
        self.post_init()

    def forward(.....):

In order to use AutoModelForTokenClassification, I added these 2 lines of code:

AutoConfig.register("roberta", MyConfig, exist_ok=True)
AutoModelForTokenClassification.register(MyConfig, MyModelForTokenClassification, exist_ok=True)

When creating a model for training:

model = AutoModelForTokenClassification.from_pretrained("roberta-base", ...)

this message appears:
Some weights of MyModelForTokenClassification were not initialized from the model checkpoint
at roberta-base and are newly initialized:

It shows that the model has not been initialised with the weights from ‘roberta-base’.
What else should I do to initialise MyModel with the weights from the ‘roberta-base’ model?
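The `roberta-base` checkpoint contains no token-classification head, so a warning that lists only `classifier.weight` and `classifier.bias` is expected; the backbone loaded correctly if nothing else appears. One way to check is `from_pretrained(..., output_loading_info=True)`. The sketch below is an illustration under stated assumptions: a tiny random `RobertaForMaskedLM` saved locally stands in for `roberta-base` (which uses the same `roberta.*` / `lm_head.*` key layout), and the backbone attribute is named `roberta` to match `RobertaPreTrainedModel.base_model_prefix`. With a different name such as `model`, the `roberta.*` checkpoint keys would not map onto the submodule and the backbone weights would be reported as missing too.

```python
import tempfile

import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaPreTrainedModel


class MyModelForTokenClassification(RobertaPreTrainedModel):
    # Cut-down class for the loading check only; forward() is omitted.
    def __init__(self, config):
        super().__init__(config)
        # The attribute name should equal base_model_prefix ("roberta") so that
        # checkpoint keys such as "roberta.encoder.layer.0..." land on this submodule.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()


# Stand-in for "roberta-base": a tiny random masked-LM checkpoint saved locally,
# stored with the same "roberta.*" / "lm_head.*" key layout.
tiny = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                     num_attention_heads=2, intermediate_size=37)

with tempfile.TemporaryDirectory() as ckpt:
    RobertaForMaskedLM(tiny).save_pretrained(ckpt)
    model, info = MyModelForTokenClassification.from_pretrained(
        ckpt, num_labels=4, output_loading_info=True
    )

# Only the new head should be missing; the checkpoint's lm_head.* weights are simply unused.
print(sorted(info["missing_keys"]))  # ['classifier.bias', 'classifier.weight']
```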

The problem has been resolved: ‘additional_data’ should be passed to the forward method instead of the constructor.
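A minimal sketch of that change, assuming `additional_data` holds per-token features of shape `(batch, seq_len, additional_dim)` that are concatenated to RoBERTa's hidden states. Both the concatenation and the `additional_dim` config field are hypothetical stand-ins for the elided `hidden_size = .....` logic. Keeping only the feature width in the config means `from_pretrained` rebuilds a classifier with the same shape as the saved one, while the data itself arrives with each batch.

```python
import torch
from torch import nn
from transformers import RobertaConfig, RobertaModel, RobertaPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput


class MyConfig(RobertaConfig):
    def __init__(self, additional_dim=0, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical field: width of the extra per-token features.
        self.additional_dim = additional_dim


class MyModelForTokenClassification(RobertaPreTrainedModel):
    config_class = MyConfig

    def __init__(self, config):  # only the config here, no data
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Head width depends only on the config, so a reloaded model gets the same shape.
        self.classifier = nn.Linear(config.hidden_size + config.additional_dim, config.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, additional_data=None, labels=None):
        hidden = self.roberta(input_ids, attention_mask=attention_mask).last_hidden_state
        if self.config.additional_dim > 0:
            if additional_data is None:
                raise ValueError("additional_data is required when config.additional_dim > 0")
            # additional_data: (batch, seq_len, additional_dim), supplied with each batch.
            hidden = torch.cat([hidden, additional_data.to(hidden.dtype)], dim=-1)
        logits = self.classifier(self.dropout(hidden))
        loss = None
        if labels is not None:
            # cross_entropy ignores label -100 by default, as the transformers heads do.
            loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        return TokenClassifierOutput(loss=loss, logits=logits)


config = MyConfig(additional_dim=4, vocab_size=100, hidden_size=32, num_hidden_layers=1,
                  num_attention_heads=2, intermediate_size=37, num_labels=3)
model = MyModelForTokenClassification(config)
input_ids = torch.randint(3, 100, (2, 7))
extra_features = torch.randn(2, 7, 4)
labels = torch.randint(0, 3, (2, 7))
out = model(input_ids, additional_data=extra_features, labels=labels)
print(out.logits.shape)  # torch.Size([2, 7, 3])
```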