This is the custom class:
import os

import torch.nn as nn
from transformers import BertModel, XLMRobertaModel

# DIR and default_bert_params are assumed to be defined elsewhere in this module.


class BERTMLM(nn.Module):
    def __init__(self, language_model, topics, input_channels=64, params=default_bert_params,
                 save_pretrained_model=False, threshold=None,
                 lr_method: str = 'linear'):
        super(BERTMLM, self).__init__()
        local_path = DIR + '/../' + language_model
        if os.path.exists(local_path):
            # Use a manually downloaded copy (saved earlier with model.save_pretrained()).
            lang_model = local_path
        else:
            lang_model = language_model
        if language_model in ('bert-base-multilingual-cased', 'bert-base-cased',
                              'bert-base-german-cased'):
            self.bert_model = BertModel.from_pretrained(lang_model)
        elif language_model in ('xlm-roberta-base', 'xlm-roberta-large'):
            self.bert_model = XLMRobertaModel.from_pretrained(lang_model)
        else:
            raise NotImplementedError()
        # The loaded model's own config provides hidden_size and vocab_size for every branch.
        self.bert_config = self.bert_model.config
        self.output_vec_size = self.bert_config.hidden_size  # 768 for the base models, 1024 for xlm-roberta-large
        if not os.path.exists(local_path) and save_pretrained_model:
            self.bert_model.save_pretrained(local_path)
        self.threshold = threshold
        self.lr = params['lr']
        self.max_epochs_without_loss_reduction = params['max_epochs_without_loss_reduction']
        self.epochs = params['epochs']
        self.params = params
        self.lr_method = lr_method
        self.batch_size = params['batch_size']
        self.vocab_size = self.bert_config.vocab_size
        # Final layer - LM head to predict the masked token
        self.fc_layer = nn.Linear(self.bert_config.hidden_size, self.bert_config.hidden_size)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(self.bert_config.hidden_size, eps=1e-8)
        self.decoder = nn.Linear(self.bert_config.hidden_size, self.bert_config.vocab_size)

    def forward(self, b_input_ids, b_input_mask):
        # Last hidden states of the encoder: (batch_size, seq_len, hidden_size)
        out = self.bert_model(b_input_ids, attention_mask=b_input_mask, return_dict=False)[0]
        # Final layer - LM head: transform, activate, normalize, then project to vocabulary logits
        out = self.fc_layer(out)
        out = self.relu(out)
        out = self.layer_norm(out)
        out = self.decoder(out)
        return out
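
For context, here is a minimal sketch of how the class might be used for masked-token prediction, assuming DIR and default_bert_params are defined in the surrounding module. The tokenizer choice, the example params dict, and the loss comment at the end are illustrative assumptions, not part of the original code.

import torch
from transformers import BertTokenizerFast

# Hypothetical hyperparameters standing in for default_bert_params.
params = {'lr': 2e-5, 'max_epochs_without_loss_reduction': 3, 'epochs': 10, 'batch_size': 16}

tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model = BERTMLM('bert-base-cased', topics=None, params=params)
model.eval()

text = "The capital of France is [MASK]."
enc = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'])  # (1, seq_len, vocab_size)

# Predict the token at the [MASK] position.
mask_pos = (enc['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))

# For training, a masked-LM loss would typically be computed as
# nn.CrossEntropyLoss()(logits.view(-1, model.vocab_size), labels.view(-1)),
# with labels set to -100 (the ignore index) at non-masked positions.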