I get an in-place operation error from the wav2vec2 GumbelVectorQuantizer

import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
    <https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.
    """

    def __init__(self, num_codevector_groups, num_codevectors_per_group, conv_dim, codevector_dim):
        super().__init__()
        self.num_groups = num_codevector_groups
        self.num_vars = num_codevectors_per_group

        assert (
            codevector_dim % self.num_groups == 0
        ), f"`codevector_dim {codevector_dim} must be divisible by `num_codevector_groups` {self.num_groups} for concatenation"

        # storage for codebook variables (codewords)
        self.codevectors = nn.Parameter(
            torch.randn(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(conv_dim, self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 1

    def set_temperature(self, temperature: int):
        self.temperature = temperature

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        if mask is not None:
            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
            marginal_probs = probs.sum(dim=0) / mask.sum()
        else:
            marginal_probs = probs.mean(dim=0)

        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states, mask_time_indices=None):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project hidden states to codebook logits (num_groups * num_vars entries)
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample codevector probs via gumbel-softmax in a differentiable way
            codevector_probs = F.gumbel_softmax(
                hidden_states, tau=self.temperature, hard=True
            )

            # compute perplexity
            codevector_soft_dist = F.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
            
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one-hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = (
            codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
            .sum(-2)
            .view(batch_size, sequence_length, -1)
        )

        return codevectors, perplexity

This is the error I get (autograd anomaly detection is enabled):

[W python_anomaly_mode.cpp:104] Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
  File "train_Mel2Vec.py", line 352, in <module>
    train(training_args)
  File "train_Mel2Vec.py", line 318, in train
    train_sequence(model, train_dataloader, eval_dataloader, optimizer, scheduler, args, device, vocab, writer, is_pretrain)
  File "train_Mel2Vec.py", line 174, in train_sequence
    model(input_mels.to(device), mel_lengths.to(device))
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/leecho/xi-stt/conformer/conformer/Mel2Vec.py", line 507, in forward
    Quantized_features, perplexity = self.VecQuantization(extract_features, time_mask)
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/leecho/xi-stt/conformer/conformer/Mel2Vec.py", line 57, in forward
    hidden_states = self.weight_proj(hidden_states)
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
    output = input.matmul(weight.t())
 (function _print_stack)
Traceback (most recent call last):
  File "train_Mel2Vec.py", line 352, in <module>
    train(training_args)
  File "train_Mel2Vec.py", line 318, in train
    train_sequence(model, train_dataloader, eval_dataloader, optimizer, scheduler, args, device, vocab, writer, is_pretrain)
  File "train_Mel2Vec.py", line 178, in train_sequence
    loss.backward()
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1808, 768]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi, I'm trying to reproduce the transformers wav2vec2 pretraining code, but I'm getting an in-place operation error from the Linear layer in GumbelVectorQuantizer.
I copied the quantizer almost verbatim from the Wav2Vec2ForPreTraining documentation; the only change is the codevectors' initialization (torch.FloatTensor -> torch.randn).

The error occurs when gumbel_softmax is used (self.is_pretrain=True).

As a workaround, I pass a .clone() of the extracted features into the quantizer:

Quantized_features, perplexity = self.VecQuantization(extract_features.clone(), time_mask)

and the error goes away, but is that the correct solution?
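For my own understanding, here is a toy sketch (hypothetical names, not my actual model) of the failure mode I think I'm hitting: a tensor that a Linear layer saved for backward gets modified in-place later in the forward pass:

import torch
import torch.nn as nn

linear = nn.Linear(4, 4)
x = torch.randn(3, 4, requires_grad=True)

features = x * 2            # upstream computation (stand-in for extract_features)
out = linear(features)      # the matmul saves `features` for its backward pass
features.add_(1.0)          # in-place edit bumps the tensor's version counter
out.sum().backward()        # RuntimeError: ... modified by an inplace operation

# Passing a clone instead leaves the saved tensor untouched:
#   out = linear(features.clone())
#   features.add_(1.0)      # now harmless for this backward

If something later in my forward modifies extract_features in-place after the quantizer has consumed it, that would explain why the .clone() hides the error rather than fixing the underlying operation.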

Additionally, when I pretrain the model with this setup, the diversity loss gets stuck around 0.06 and barely moves.
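(By diversity loss I mean the usual perplexity-based term; a minimal sketch of the formulation I have in mind, assuming the standard wav2vec2-style definition rather than my exact code:)

import torch

def diversity_loss(perplexity: torch.Tensor, num_groups: int, num_vars: int) -> torch.Tensor:
    # Assumed wav2vec2-style formulation: push the summed per-group perplexity
    # toward the total number of codevectors so every codebook entry gets used.
    num_codevectors = num_groups * num_vars
    return (num_codevectors - perplexity) / num_codevectors

The contrastive loss I wrote is below: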

    def _compute_contrastive_loss(
        self,
        predicted_features: Tensor, 
        quantized_features: Tensor, 
        mask_indices_list: List[List],
    ):
        with torch.no_grad():
            batch_size, seq_len, encoder_dim = predicted_features.shape

            q_mask = torch.zeros_like(predicted_features).bool()  # selects candidate timesteps in quantized_features (target + distractors)
            p_mask = torch.zeros_like(predicted_features).bool()  # selects the anchor timestep in predicted_features
            location = []       # flattened column index of each anchor's target among all candidates
            location_base = 0
            for i in range(batch_size):
                mask_indices = mask_indices_list[i]

                K = max(int(len(mask_indices) * self.distractors_rate), 1)  # number of distractors
                candidates = random.sample(mask_indices, K)

                p_mask[i, candidates[0]] = True  # the first sampled index is used as the target timestep

                for c in candidates:
                    q_mask[i, c] = True

                location.append(location_base + sum(q_mask[i, :candidates[0], 0].detach()))
                location_base = location_base + K
    
        positive = predicted_features[p_mask].view(-1, encoder_dim)
        negatives = quantized_features[q_mask].view(-1, encoder_dim)
        
        logits = F.cosine_similarity(positive.unsqueeze(1), negatives, dim=-1, eps=1e-6) / self.k_temperature
        
#         neg_is_pos = (positive == negatives).all(-1)
#         if neg_is_pos.any():
#             logits[1:][neg_is_pos[1:]] = float("-inf")
                
        logits = -torch.log(F.softmax(logits, dim=-1))

        loss = torch.sum(logits[range(batch_size), location])

        return loss

In fact, I suspect the messy contrastive loss code I wrote above may also be contributing to this error.
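To narrow the in-place issue down, the next thing I plan to try is watching the tensor version counter (the same counter autograd's error message is complaining about) around the suspicious calls; a rough sketch with placeholder names:

import torch

def report_version(name: str, t: torch.Tensor) -> None:
    # `_version` is the internal counter autograd checks against the version it
    # saved at forward time; it increases each time the tensor is modified in-place.
    print(f"{name}: version={t._version}")

# Placeholder usage inside my forward (not the real code):
#   report_version("extract_features", extract_features)   # expect 0 here
#   quantized_features, perplexity = self.VecQuantization(extract_features, time_mask)
#   ...
#   report_version("extract_features", extract_features)   # > 0 means something edited it in-place

If the counter increases between those two points, that should point at the operation the .clone() is currently papering over.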