An efficient way of loading a model that was saved with torch.save

Hello, after fine-tuning a BERT model from Hugging Face’s transformers (specifically ‘bert-base-cased’), I can’t seem to load the model efficiently.

My model class is as follows:

```python
import torch
import torch.nn as nn
import transformers

# config is a dict defined elsewhere with the keys 'MODEL_ID' (here 'bert-base-cased'),
# 'HIDDEN_SIZE' and 'NUM_LABELS'.


class Model(nn.Module):
    def __init__(self, model_name='bert_model'):
        super(Model, self).__init__()
        # Downloads (or reads from the local cache) the pretrained weights on every instantiation.
        self.bert = transformers.BertModel.from_pretrained(config['MODEL_ID'], return_dict=False)
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
        self.model_name = model_name

    def forward(self, ids, mask, token_type_ids):
        # With return_dict=False, BERT returns a tuple; o2 is the pooled [CLS] output.
        _, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        bo = self.bert_drop(o2)
        output = self.out(bo)
        return output
```

I then create a model, fine-tune it, and save it with the following code:

```python
device = torch.device('cuda')
model = Model(model_name)
model.to(device)

TrainModel(model, data)
torch.save(model.state_dict(), config['MODEL_SAVE_PATH'] + f'{model_name}.bin')
```
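A quick way to see what that .bin file holds is to list the keys of the state_dict. The minimal sketch below is only a stand-in: it rebuilds the same layer layout from a hypothetical tiny `BertConfig` (so nothing is downloaded) with a placeholder label count of 3, then prints every key that does not belong to BERT.

```python
import torch.nn as nn
import transformers

# Hypothetical tiny config so the sketch runs offline; the real model
# uses the 'bert-base-cased' architecture instead.
cfg = transformers.BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                              num_attention_heads=2, intermediate_size=64)


class Model(nn.Module):
    def __init__(self, num_labels=3):  # num_labels=3 is a placeholder
        super().__init__()
        self.bert = transformers.BertModel(cfg)  # built from the config, no download
        self.bert_drop = nn.Dropout(0.0)         # no parameters, so no state_dict keys
        self.out = nn.Linear(cfg.hidden_size, num_labels)


state = Model().state_dict()
non_bert_keys = [k for k in state if not k.startswith("bert.")]
print(non_bert_keys)  # ['out.weight', 'out.bias']
```

Every other entry starts with `bert.`, so the file written by torch.save already contains the complete fine-tuned encoder. The weights fetched by from_pretrained in `__init__` are overwritten as soon as load_state_dict runs.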

I can load the model with this code:

```python
model = Model(model_name=model_name)
model.load_state_dict(torch.load(model_path))
```

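However, every time I load a model with the Model() class, it downloads (or reads from the local cache) the pretrained ‘bert-base-cased’ weights and loads them into memory, because of the `transformers.BertModel.from_pretrained(...)` call in `Model.__init__`. Those weights are then immediately overwritten by load_state_dict, so this is not very efficient. Is there another way to load the model?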
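One way to keep torch.save/load_state_dict but skip the download is to build the BERT body from its configuration only: `transformers.BertModel(config)` allocates randomly initialised layers, and load_state_dict then fills in the fine-tuned values. This is a minimal sketch under assumptions: the config is passed in (in practice `transformers.BertConfig.from_pretrained('bert-base-cased')` fetches only the small config.json, which is also cached), and the tiny config, label count and file name are hypothetical so it runs offline.

```python
import os
import tempfile

import torch
import torch.nn as nn
import transformers


class Model(nn.Module):
    def __init__(self, bert_config, num_labels, model_name="bert_model"):
        super().__init__()
        # Architecture only: random weights, nothing downloaded.
        self.bert = transformers.BertModel(bert_config)
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(bert_config.hidden_size, num_labels)
        self.model_name = model_name

    def forward(self, ids, mask, token_type_ids):
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        pooled = outputs[1]  # pooled [CLS] output, for both tuple and ModelOutput returns
        return self.out(self.bert_drop(pooled))


# Hypothetical tiny config; in practice something like
# transformers.BertConfig.from_pretrained("bert-base-cased").
bert_config = transformers.BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                                      num_attention_heads=2, intermediate_size=64)

# Stand-in for the fine-tuned model and its saved checkpoint.
trained = Model(bert_config, num_labels=3)
model_path = os.path.join(tempfile.mkdtemp(), "bert_model.bin")
torch.save(trained.state_dict(), model_path)

# Loading: build the skeleton from the config, then restore the fine-tuned weights.
restored = Model(bert_config, num_labels=3)
restored.load_state_dict(torch.load(model_path, map_location="cpu"))
restored.eval()
```

The first training run would still start from `BertModel.from_pretrained(...)`; only the later reloads switch to the config-only constructor.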

Instead of `class Model(nn.Module):`, you can do `class Model(PreTrainedModel):`.

This lets you use the built-in save and load mechanisms. Instead of torch.save, you can call `model.save_pretrained("your-save-dir/")`. After that, you can load the model with `Model.from_pretrained("your-save-dir/")`.
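A rough sketch of that suggestion, assuming the head stays the same: it subclasses `transformers.BertPreTrainedModel`, the BERT-specific subclass of PreTrainedModel that already sets `config_class` and the weight initialisation, so `__init__` only needs the config. The body is built with `BertModel(config)` rather than `BertModel.from_pretrained(...)`; otherwise every load would still fetch the base checkpoint. The tiny config, label count and temporary directory (standing in for "your-save-dir/") are hypothetical so the example runs offline.

```python
import tempfile

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class Model(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)  # architecture only; weights come from from_pretrained
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()  # applies the BERT weight initialisation to the new layers

    def forward(self, ids, mask, token_type_ids):
        pooled = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)[1]
        return self.out(self.bert_drop(pooled))


# Hypothetical tiny config. For the first fine-tuning run you would instead start from
# the pretrained body, e.g. Model.from_pretrained("bert-base-cased", num_labels=...).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=3)
model = Model(config)

save_dir = tempfile.mkdtemp()
model.save_pretrained(save_dir)             # writes config.json plus the weights
reloaded = Model.from_pretrained(save_dir)  # rebuilt from the local files only
```

Because the config is saved next to the weights, `from_pretrained` can rebuild the exact architecture from the directory without touching the Hub.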


Would that still allow me to stack torch layers?

Yes, you can still build your torch model as you are used to, because PreTrainedModel also subclasses nn.Module. So you get the same functionality as you had before PLUS the HuggingFace extras.
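To illustrate, a minimal sketch in which ordinary torch layers are stacked on top of BERT inside a PreTrainedModel subclass (via `BertPreTrainedModel`); the two-layer head, its sizes and the tiny config are hypothetical. The resulting model is a plain nn.Module, so the usual tooling such as optimizers works unchanged.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel, PreTrainedModel


class StackedModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Plain torch layers stacked on the pooled output (hypothetical sizes).
        self.head = nn.Sequential(
            nn.Linear(config.hidden_size, 16),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(16, config.num_labels),
        )
        self.post_init()

    def forward(self, ids, mask, token_type_ids):
        pooled = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)[1]
        return self.head(pooled)


config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=3)
model = StackedModel(config)

print(issubclass(PreTrainedModel, nn.Module))  # True
print(isinstance(model, nn.Module))            # True
# Usual nn.Module tooling applies, e.g. an optimizer over all parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
```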
