Subclassing a pretrained model for a new objective

I would like to use a pretrained model as an encoder for a new task. The task essentially consists of multiple sequence classification objectives, like in the ...ForSequenceClassification models, but with one output layer per subtask.

I could just create wrappers around the encoder, but I’d like to subclass PreTrainedModel so it integrates better with the Trainer class. How exactly should I do this? Do I need to create a config class as well? At the very least, I will need to add an extra list or dict to the config specifying how many classes each subtask has.

Thanks!


You can definitely subclass PretrainedConfig for your custom config and PreTrainedModel for your custom model; you then have access to all the methods of the library.
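A minimal sketch of the config side, assuming the extra information is a dict mapping each subtask name to its number of classes. The class name `MultiTaskConfig`, the `model_type` string and the `num_labels_per_task` field are hypothetical names chosen for illustration. The extra field is stored as an attribute, so `save_pretrained` writes it to `config.json` and `from_pretrained` reads it back.

```python
import tempfile

from transformers import PretrainedConfig


class MultiTaskConfig(PretrainedConfig):
    # Hypothetical identifier for this custom architecture.
    model_type = "multitask-encoder"

    def __init__(self, num_labels_per_task=None, **kwargs):
        # Hypothetical extra field: subtask name -> number of classes, e.g. {"topic": 5}.
        # Attributes set here are serialized to config.json by save_pretrained.
        self.num_labels_per_task = dict(num_labels_per_task or {})
        # Standard options (id2label, return_dict, ...) are handled by the base class.
        super().__init__(**kwargs)


config = MultiTaskConfig(num_labels_per_task={"topic": 5, "sentiment": 3})
with tempfile.TemporaryDirectory() as tmp:
    config.save_pretrained(tmp)  # writes tmp/config.json
    reloaded = MultiTaskConfig.from_pretrained(tmp)

print(reloaded.num_labels_per_task)
```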

@sgugger thanks! But in that case, what is needed to make methods like from_pretrained work out of the box? I saw that the pretrained model classes have a class attribute called config_class; is setting that enough?

The config_class attribute is what the Transformers library uses to find the right config class. In your case, you might have to load in two steps:

config = CustomConfig.from_pretrained(path_to_folder_with_config_and_weights)
model = CustomModel.from_pretrained(path_to_folder_with_config_and_weights, config=config)
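A sketch of the model side together with the two-step loading, under stated assumptions: `MultiTaskConfig`, `MultiTaskModel`, the `encoder_config` field and all layer sizes are hypothetical, and the encoder is a small randomly initialized `BertModel` built from settings stored in the config so the example runs offline. Note that `config` is passed to `from_pretrained` as a keyword argument; passed positionally, it would be forwarded to `__init__` as an extra argument. As far as I know, setting `config_class` is also what lets `from_pretrained` read `config.json` on its own, so the explicit `config=` step is optional for a folder written by `save_pretrained`.

```python
import tempfile

import torch
from torch import nn
from transformers import BertConfig, BertModel, PretrainedConfig, PreTrainedModel


class MultiTaskConfig(PretrainedConfig):
    model_type = "multitask-encoder"  # hypothetical identifier

    def __init__(self, num_labels_per_task=None, encoder_config=None, **kwargs):
        self.num_labels_per_task = dict(num_labels_per_task or {})
        # Hypothetical field: keyword arguments for BertConfig, kept as a plain dict
        # so that it serializes to config.json.
        self.encoder_config = dict(encoder_config or {})
        super().__init__(**kwargs)


class MultiTaskModel(PreTrainedModel):
    # Tells from_pretrained which config class to build from config.json.
    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # Encoder built from the stored settings (random weights in this sketch).
        self.encoder = BertModel(BertConfig(**config.encoder_config))
        hidden = self.encoder.config.hidden_size
        # One linear output layer per subtask.
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden, n) for task, n in config.num_labels_per_task.items()}
        )

    def forward(self, input_ids, attention_mask=None):
        # Returns per-subtask logits; computing a loss for Trainer is left out here.
        pooled = self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return {task: head(pooled) for task, head in self.heads.items()}


config = MultiTaskConfig(
    num_labels_per_task={"topic": 5, "sentiment": 3},
    encoder_config={"vocab_size": 100, "hidden_size": 32, "num_hidden_layers": 2,
                    "num_attention_heads": 2, "intermediate_size": 64},
)
model = MultiTaskModel(config).eval()  # eval() disables dropout for a stable comparison

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)  # writes config.json and the weights
    loaded_config = MultiTaskConfig.from_pretrained(tmp)                  # step 1
    loaded = MultiTaskModel.from_pretrained(tmp, config=loaded_config)    # step 2

input_ids = torch.tensor([[1, 5, 7, 2]])
with torch.no_grad():
    out_original = model(input_ids)
    out_loaded = loaded(input_ids)
```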

Ok. But how can I load the pretrained model (i.e., the encoder inside my class)?
I tried doing CustomModel.from_pretrained(path_to_pretrained, additional_config_data), but that ignored all the weights in the checkpoint (name mismatches, I suppose?).

Did you save the corresponding model with save_pretrained?

Nope, I haven’t even fine-tuned the model yet 🙂
I’m calling from_pretrained on the encoder directly, after creating the classifier object and before training, but that looks hacky.

I’m not sure what you want to do, but calling from_pretrained on your class with weights saved from another class will not work. If you want to use a pretrained model for part of your custom model, you should call from_pretrained when defining that part of your custom model.
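A sketch of that advice, assuming the config stores a path or hub id for the encoder in a hypothetical `encoder_name_or_path` field. The encoder is loaded with `AutoModel.from_pretrained` inside `__init__`, so it starts from the pretrained weights while the new heads start from random initialization. To keep the sketch offline, it first saves a tiny randomly initialized `BertModel` to a temporary directory and points the config at it; in practice that path would be a real pretrained checkpoint.

```python
import tempfile

import torch
from torch import nn
from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig, PreTrainedModel


class MultiTaskConfig(PretrainedConfig):
    model_type = "multitask-encoder"  # hypothetical identifier

    def __init__(self, encoder_name_or_path=None, num_labels_per_task=None, **kwargs):
        # Hypothetical fields: where to load the encoder from, and classes per subtask.
        self.encoder_name_or_path = encoder_name_or_path
        self.num_labels_per_task = dict(num_labels_per_task or {})
        super().__init__(**kwargs)


class MultiTaskModel(PreTrainedModel):
    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained part: loaded with from_pretrained where it is defined.
        self.encoder = AutoModel.from_pretrained(config.encoder_name_or_path)
        hidden = self.encoder.config.hidden_size
        # New part: one freshly initialized linear head per subtask.
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden, n) for task, n in config.num_labels_per_task.items()}
        )

    def forward(self, input_ids, attention_mask=None):
        pooled = self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return {task: head(pooled) for task, head in self.heads.items()}


# Stand-in for a real pretrained checkpoint, so the example needs no download.
tiny = BertModel(BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                            num_attention_heads=2, intermediate_size=64))
with tempfile.TemporaryDirectory() as tmp:
    tiny.save_pretrained(tmp)
    config = MultiTaskConfig(encoder_name_or_path=tmp,
                             num_labels_per_task={"topic": 5, "sentiment": 3})
    model = MultiTaskModel(config).eval()

with torch.no_grad():
    logits = model(torch.tensor([[1, 5, 7, 2]]))
```

One consequence of this design: reloading a fine-tuned model with `MultiTaskModel.from_pretrained` would still run the encoder's `from_pretrained` inside `__init__` before the saved weights are loaded, so the original encoder checkpoint likely has to stay reachable.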