Hello, after fine-tuning a BERT model from Hugging Face's transformers (specifically 'bert-base-cased'), I can't seem to load the model efficiently.
My model class is as follows:
```python
import torch
import torch.nn as nn
import transformers  # needed for transformers.BertModel below


class Model(nn.Module):
    def __init__(self, model_name='bert_model'):
        super(Model, self).__init__()
        # Fetches the pretrained weights (from the hub or the local cache)
        # every time the class is instantiated.
        self.bert = transformers.BertModel.from_pretrained(config['MODEL_ID'], return_dict=False)
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
        self.model_name = model_name

    def forward(self, ids, mask, token_type_ids):
        # With return_dict=False, BertModel returns (sequence_output, pooled_output).
        _, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        bo = self.bert_drop(o2)
        output = self.out(bo)
        return output
```
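For reference, `config` is just a plain dict of constants I define elsewhere; a hypothetical version (the values below are illustrative, not my real settings) would look like this:

```python
# Hypothetical stand-in for the config dict used throughout this post.
config = {
    'MODEL_ID': 'bert-base-cased',
    'HIDDEN_SIZE': 768,        # hidden size of bert-base-cased
    'NUM_LABELS': 2,           # number of output classes
    'MODEL_SAVE_PATH': './models/',
}
```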
I then create a model, fine-tune it, and save it with the following code:
```python
device = torch.device('cuda')
model = Model(model_name)
model.to(device)
TrainModel(model, data)  # my fine-tuning loop (definition omitted)
torch.save(model.state_dict(), config['MODEL_SAVE_PATH'] + f'{model_name}.bin')
```
I can load the model with this code:
```python
model = Model(model_name=model_name)
model.load_state_dict(torch.load(model_path))
```
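(Side note: since the checkpoint was saved from a CUDA model, loading it on a CPU-only machine needs `map_location`, which is standard `torch.load` behaviour:)

```python
# Remap CUDA tensors to CPU when no GPU is available.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
```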
However, the problem is that every time I load a model with the Model() class, the `from_pretrained` call in `__init__` downloads (or at least reads from the cache into memory) the original pretrained weights, even though `load_state_dict` immediately overwrites them with my fine-tuned weights. This is not very efficient; is there another way to load the model, for example something like the sketch below?
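What I have in mind is building the architecture from its config alone so that no pretrained weights are fetched. I am not sure this is the right way to use the API; the `ModelForLoading` name and the `BertConfig` usage are my own guesses:

```python
from transformers import BertConfig, BertModel

class ModelForLoading(Model):
    """Same architecture as Model, but built from a BertConfig so no
    pretrained weights are downloaded (load_state_dict overwrites them
    anyway). The class name and BertConfig usage are my own guesses."""
    def __init__(self, model_name='bert_model'):
        nn.Module.__init__(self)  # bypass Model.__init__ and its from_pretrained call
        # Only the small config JSON is fetched here, not the weights.
        bert_config = BertConfig.from_pretrained(config['MODEL_ID'], return_dict=False)
        self.bert = BertModel(bert_config)  # randomly initialised architecture
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
        self.model_name = model_name
        # forward() is inherited from Model unchanged.

model = ModelForLoading(model_name=model_name)
model.load_state_dict(torch.load(model_path))  # fine-tuned weights fill everything in
```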