An efficient way of loading a model that was saved with torch.save

Hello, after fine-tuning a BERT model from Hugging Face's transformers (specifically 'bert-base-cased'), I can't seem to load the model efficiently.

My model class is as follows:

1. import torch
2. import torch.nn as nn
3. class Model(nn.Module):
4.   def __init__(self, model_name='bert_model'):
5.     super(Model, self).__init__()
6.     self.bert = transformers.BertModel.from_pretrained(config['MODEL_ID'], return_dict=False)
7.     self.bert_drop = nn.Dropout(0.0)
8.     self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
9.     self.model_name = model_name
10.   
11.   def forward(self, ids, mask, token_type_ids):
12.     _, o2 = self.bert(ids, attention_mask = mask, token_type_ids = token_type_ids)
13.     bo = self.bert_drop(o2)
14.     output = self.out(bo)
15.     return output

I then create a model, fine-tune it, and save it with the following code:

device = torch.device('cuda')
model = Model(model_name)
model.to(device)

TrainModel(model, data)
torch.save(model.state_dict(), config['MODEL_SAVE_PATH'] + f'{model_name}.bin')

I can load the model with this code:

model = Model(model_name=model_name)
model.load_state_dict(torch.load(model_path))

However, the problem is that every time I load a model with the Model() class, it fetches a model from Hugging Face's transformers and reads it into memory, because of line 6 in the Model() class. This is not very efficient. Is there another way to load the model?

Instead of

class Model(nn.Module):

you can do

class Model(PreTrainedModel):

This allows you to use the built-in save and load mechanisms. Instead of torch.save you can do model.save_pretrained("your-save-dir/"). After that you can load the model with Model.from_pretrained("your-save-dir/").
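A minimal sketch of what this could look like, assuming a reasonably recent transformers release. It subclasses BertPreTrainedModel (the BERT-specific subclass of PreTrainedModel) rather than PreTrainedModel directly. The class name BertClassifier, the tiny config values, and the temporary save directory are hypothetical and only there to keep the example small and offline. Note that the backbone is created with BertModel(config), not BertModel.from_pretrained(...), so constructing the class does not fetch any weights.

```python
import tempfile

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class BertClassifier(BertPreTrainedModel):
    """Hypothetical wrapper: BERT backbone plus dropout and a linear head."""

    def __init__(self, config):
        super().__init__(config)
        # Build the architecture from the config only; no weights are fetched here.
        self.bert = BertModel(config)
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()  # runs the library's weight initialisation

    def forward(self, ids, mask, token_type_ids):
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        return self.out(self.bert_drop(outputs.pooler_output))


# Tiny, randomly initialised config (assumed values) so the demo runs without a download.
tiny_config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,
)
model = BertClassifier(tiny_config).eval()  # stand-in for a fine-tuned model

with tempfile.TemporaryDirectory() as save_dir:
    model.save_pretrained(save_dir)  # writes the config and the weights
    reloaded = BertClassifier.from_pretrained(save_dir).eval()  # reads only from save_dir
```

With this layout, the first-time initialisation from the hub would presumably be BertClassifier.from_pretrained('bert-base-cased', num_labels=...), which loads the backbone weights and leaves the new head randomly initialised, the same pattern the library's own BertForSequenceClassification follows.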

Would that still allow me to stack torch layers?

Yes, you can still build your torch model as you are used to, because PreTrainedModel also subclasses nn.Module. So you get the same functionality as you had before PLUS the HuggingFace extras.

Hi! Will using Model.from_pretrained() with the code above trigger a download of a fresh bert model?

I’m thinking of a case where, for example, config['MODEL_ID'] = 'bert-base-uncased'. We then fine-tune the model and save it with save_pretrained(). When calling Model.from_pretrained(), a new object will be generated by calling __init__(), and line 6 would cause a new set of weights to be downloaded. Am I understanding correctly?

If this is the case, what would be the best way to avoid this and actually load the weights we saved?

Thanks!
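One possible way to sidestep this while keeping a plain nn.Module and torch.save, sketched under assumptions: pass the BERT backbone into the constructor instead of calling from_pretrained inside __init__. The first training run can then pass BertModel.from_pretrained(...), while later loads pass BertModel(bert_config), which only builds the architecture and leaves the weights to load_state_dict. The num_labels argument, the tiny config, and the in-memory buffer standing in for the .bin file are hypothetical choices for this example.

```python
import io

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class Model(nn.Module):
    def __init__(self, bert, num_labels):
        super().__init__()
        # The backbone is injected, so the caller decides whether it comes from
        # from_pretrained (first run) or from a bare config (reload).
        self.bert = bert
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(bert.config.hidden_size, num_labels)

    def forward(self, ids, mask, token_type_ids):
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        return self.out(self.bert_drop(outputs.pooler_output))


# Tiny config (assumed values) so the demo runs offline; a real run would reuse
# the configuration of the checkpoint that was fine-tuned.
tiny_config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)

# Stand-in for the fine-tuned model; a first run would pass
# BertModel.from_pretrained(...) here instead.
trained = Model(BertModel(tiny_config), num_labels=3).eval()

buffer = io.BytesIO()  # stands in for the .bin file on disk
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Reload path: architecture from the config, weights from the saved state dict.
restored = Model(BertModel(tiny_config), num_labels=3).eval()
restored.load_state_dict(torch.load(buffer))
```

The trade-off is that the BertConfig has to be kept alongside the saved state dict (for example via its own save_pretrained), since the state dict alone does not describe the architecture.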