Wav2Vec2 fine-tuning with multi-GPU

Here is a WIP PR that makes it work with DeepSpeed:

wav2vec2 has 2 peculiarities:

  1. it randomly skips layers (LayerDrop)! I think this is what requires `find_unused_parameters` - under normal distributed training and also under ZeRO-2. Under ZeRO-3 all GPUs must run in sync, so this problem goes away (see the PR). A DDP sketch follows this list.
  2. it uses weight_norm, which re-creates 2 parameters in a forward pre-hook; this has all kinds of potential side effects (see the short demo after the list). I am attempting to write a fused version of weight_norm + Conv1d that doesn't use any tricks, but I haven't fully sorted it out yet.
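
To illustrate item 1, here is a minimal DDP sketch (the checkpoint name and launch details are illustrative only; with the HF Trainer the equivalent switch is `TrainingArguments(ddp_find_unused_parameters=True)`):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import Wav2Vec2ForCTC

dist.init_process_group("nccl")  # assumes a torchrun / torch.distributed.launch setup
local_rank = dist.get_rank() % torch.cuda.device_count()

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base").to(local_rank)
model = DDP(
    model,
    device_ids=[local_rank],
    # LayerDrop means some layers produce no gradients on a given step;
    # without this flag DDP errors out waiting for those gradients
    find_unused_parameters=True,
)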

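And to see what item 2 refers to, here is a small demo of the stock nn.utils.weight_norm hook (this mirrors what transformers' Wav2Vec2PositionalConvEmbedding does today; the layer sizes are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv1d(16, 16, kernel_size=3, padding=1, groups=2)
conv = nn.utils.weight_norm(conv, name="weight", dim=2)

# `weight` has been replaced by two freshly created parameters ...
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

# ... and `conv.weight` is re-created from them by a forward pre-hook on every
# call - the step that interacts badly with ZeRO's parameter partitioning
out = conv(torch.randn(1, 16, 50))
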
Here is the work in progress:

import torch.nn as nn
from torch.nn.parameter import Parameter
from torch import _weight_norm, norm_except_dim

from transformers.activations import ACT2FN
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2SamePadLayer

class Conv1dWithWeightNorm(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dim = 2  # wav2vec2 normalizes the weight over dim=2 (the kernel dim)

        # lazy import so the class can be defined without deepspeed installed
        import deepspeed

        # under ZeRO-3 `self.weight` is partitioned across GPUs, so gather it;
        # the g/v split must happen inside the context, while the full tensor
        # is available
        with deepspeed.zero.GatheredParameters(self.weight):
            self.weight_g = Parameter(norm_except_dim(self.weight, 2, self.dim).data)
            self.weight_v = Parameter(self.weight.data)
        del self._parameters["weight"]
        self.weight = _weight_norm(self.weight_v, self.weight_g, self.dim)

    def compute_weight(self):
        # recombine the trainable g/v decomposition into the effective weight;
        # no new Parameters are created here, so gradients keep flowing to g and v
        return _weight_norm(self.weight_v, self.weight_g, self.dim)

    def forward(self, input):
        self.weight = self.compute_weight()
        return self._conv_forward(input, self.weight, self.bias)

class Wav2Vec2PositionalConvEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = Conv1dWithWeightNorm(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )
        self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
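
For a quick shape sanity check, a usage sketch (assumes transformers and deepspeed are installed; GatheredParameters should be a no-op when the weight isn't actually ZeRO-partitioned):

import torch
from transformers import Wav2Vec2Config

config = Wav2Vec2Config()  # defaults: hidden_size=768, num_conv_pos_embeddings=128
pos_conv = Wav2Vec2PositionalConvEmbedding(config)

hidden_states = torch.randn(2, 100, config.hidden_size)  # (batch, seq_len, hidden)
out = pos_conv(hidden_states)
print(out.shape)  # torch.Size([2, 100, 768]) - same shape out as in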