When I define a class that subclasses `Wav2Vec2PreTrainedModel`, I can only load the single pretrained model specified by its config:
class Wav2Vec2ForMe(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder followed by dropout and a linear CTC head.

    Submodules created in ``__init__``:
        wav2vec2: ``Wav2Vec2Model(config)``. NOTE(review): the attribute name is
            presumably what lets ``from_pretrained`` match checkpoint keys
            (``wav2vec2.*``), so keep it unchanged.
        dropout: ``nn.Dropout(config.final_dropout)`` applied to the encoder output.
        lm_head: ``nn.Linear(config.hidden_size, config.vocab_size)`` giving CTC logits.

    Adding a second pretrained encoder (e.g. ``self.ssl2``):
        ``Wav2Vec2Model(config)`` by itself only builds randomly initialised weights.
        ``__init__`` also ends with ``self.init_weights()``, which runs the model's
        weight initializer over the submodules created before it. So an encoder
        loaded with ``from_pretrained`` inside ``__init__`` may be re-initialised,
        depending on the transformers version. The outer ``from_pretrained`` may do
        the same for ``ssl2.*`` keys that are missing from the checkpoint. Either
        effect is consistent with a WER of 1.

        Suggested approach (confirm against your transformers version):
            1. Create ``self.ssl2 = Wav2Vec2Model(config)`` in ``__init__`` so it is
               registered and saved.
            2. Copy the pretrained weights in *after* the outer model is built or
               loaded, e.g.::

                   model = Wav2Vec2ForMe.from_pretrained(checkpoint)
                   model.ssl2.load_state_dict(
                       Wav2Vec2Model.from_pretrained(other_checkpoint).state_dict()
                   )

            3. Make sure ``forward`` actually uses ``self.ssl2``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # NOTE(review): init_weights() runs the weight initializer over the modules
        # built above. Pretrained weights loaded earlier in __init__ can be overwritten
        # here, so load extra pretrained encoders afterwards (see class docstring).
        self.init_weights()

    def freeze_feature_extractor(self):
        """Call ``_freeze_parameters()`` on ``self.wav2vec2.feature_extractor``.

        Presumably this stops gradient updates to the convolutional front-end;
        confirm in the transformers source for your version.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def _ctc_loss(self, logits, labels, input_values, attention_mask=None):
        """Compute the CTC loss of ``logits`` against ``labels``.

        Args:
            logits: ``lm_head`` output. Log-softmax is taken over the last dim, then
                dims 0 and 1 are swapped before ``F.ctc_loss`` (assumed to be
                (batch, frames, vocab) -> (frames, batch, vocab); confirm).
            labels: target ids. Negative entries (e.g. -100 padding) are ignored.
            input_values: used only to build an all-ones mask when ``attention_mask``
                is None.
            attention_mask: optional mask. Its row sums give the raw input lengths.

        Returns:
            The loss tensor, or ``None`` when ``labels`` is None.

        Raises:
            ValueError: if any label id is >= ``config.vocab_size``.
        """
        if labels is None:
            return None

        # Out-of-range ids would otherwise fail deep inside ctc_loss (or on device).
        if labels.numel() > 0 and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        if attention_mask is None:
            attention_mask = torch.ones_like(input_values, dtype=torch.long)
        # Presumably converts raw-sample lengths into encoder-frame lengths.
        # ctc_loss requires integer lengths, so cast explicitly.
        input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

        # Padded label positions are assumed to be filled with negative values (-100).
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)
        flattened_targets = labels.masked_select(labels_mask)

        # Compute log-probs in float32 so half-precision logits keep precision in the loss.
        log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

        # cuDNN is disabled for the CTC call (behaviour kept from the original code).
        with torch.backends.cudnn.flags(enabled=False):
            return F.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        """Encode ``input_values``, project to CTC logits and optionally compute the loss.

        Args:
            input_values: model input passed straight to ``self.wav2vec2``.
            attention_mask: optional mask forwarded to the encoder and used for the
                CTC input lengths.
            output_attentions: forwarded to the encoder. The attentions are included
                in the output.
            output_hidden_states: forwarded to the encoder. The hidden states are
                included in the output.
            return_dict: ``None`` falls back to ``config.use_return_dict``. When false,
                a tuple ``(loss, logits, hidden_states, attentions)`` is returned,
                with ``None`` entries dropped.
            labels: ``labels[0]`` is what reaches the CTC loss (see note below).

        Returns:
            ``CausalLMOutput`` with ``loss``, ``logits``, ``hidden_states`` and
            ``attentions``, or the equivalent tuple when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Ask the encoder for a ModelOutput so fields can be read by name, independent
        # of how its tuple form is laid out in a given transformers version.
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = self.dropout(outputs.last_hidden_state)
        logits_ctc = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # NOTE(review): only labels[0] is used. That is correct if `labels` is a
            # container whose first entry is the label tensor. If `labels` is itself a
            # (batch, target_len) tensor, labels[0] is just the first utterance.
            # Confirm against the data collator.
            loss = self._ctc_loss(logits_ctc, labels[0], input_values, attention_mask)

        if not return_dict:
            output = (logits_ctc,) + tuple(
                extra for extra in (outputs.hidden_states, outputs.attentions) if extra is not None
            )
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits_ctc,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
With this class I get a reasonable WER on my dataset.
But when I try to load a second pretrained model inside the `Wav2Vec2ForMe` class, using either `self.ssl2 = Wav2Vec2Model(config)` or `self.ssl2 = Wav2Vec2Model.from_pretrained('wav2vec2-base')`, the model does not seem to load successfully: the WER with the new model is always 1.
How can I implement that?