class GumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
    <https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.

    The codebook is split into ``num_codevector_groups`` groups of ``num_codevectors_per_group``
    entries. For every frame one entry per group is selected and the selected entries are
    concatenated, so each entry has ``codevector_dim // num_codevector_groups`` features.

    Args:
        num_codevector_groups: number of codebook groups.
        num_codevectors_per_group: number of entries in each group.
        conv_dim: size of the last dimension of the features passed to :meth:`forward`.
        codevector_dim: size of the last dimension of the quantized output.

    Raises:
        ValueError: if ``codevector_dim`` is not divisible by ``num_codevector_groups``.
    """

    def __init__(self, num_codevector_groups, num_codevectors_per_group, conv_dim, codevector_dim):
        super().__init__()
        self.num_groups = num_codevector_groups
        self.num_vars = num_codevectors_per_group

        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`codevector_dim` {codevector_dim} must be divisible by "
                f"`num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords)
        self.codevectors = nn.Parameter(
            torch.randn(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
        )
        # one logit per (group, codebook entry)
        self.weight_proj = nn.Linear(conv_dim, self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 1

    def set_temperature(self, temperature: float):
        """Set the gumbel-softmax temperature (``tau``) used while training."""
        self.temperature = temperature

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        """Return the codebook perplexity, summed over groups.

        Args:
            probs: tensor of shape ``(frames, groups, entries)`` holding one distribution
                per frame and group.
            mask: optional boolean tensor with ``frames`` elements in total. Only frames
                whose mask value is ``True`` contribute to the average. An all-``False``
                mask divides by zero and therefore yields ``nan``.

        Returns:
            A 0-dim tensor.
        """
        if mask is not None:
            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
            marginal_probs = probs.sum(dim=0) / mask.sum()
        else:
            marginal_probs = probs.mean(dim=0)

        # exp(entropy) of the averaged distribution per group; 1e-7 guards log(0)
        entropy = -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
        return torch.exp(entropy).sum()

    def forward(self, hidden_states, mask_time_indices=None):
        """Quantize ``hidden_states``.

        The input tensor is never modified in place: the projection allocates a new
        tensor and the only in-place op (``scatter_``) writes into a freshly created
        zero tensor.

        Args:
            hidden_states: tensor of shape ``(batch, sequence, conv_dim)``.
            mask_time_indices: optional boolean mask with ``batch * sequence`` elements,
                forwarded to the perplexity computation.

        Returns:
            Tuple ``(codevectors, perplexity)`` where ``codevectors`` has shape
            ``(batch, sequence, codevector_dim)`` and ``perplexity`` is a 0-dim tensor.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        num_frames = batch_size * sequence_length

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(num_frames * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in a differentiable way;
            # computed in float32 so half-precision inputs stay numerically stable
            codevector_probs = F.gumbel_softmax(
                hidden_states.float(), tau=self.temperature, hard=True
            ).type_as(hidden_states)

            # compute perplexity from the soft (non-sampled) distribution
            codevector_soft_dist = F.softmax(
                hidden_states.view(num_frames, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = torch.zeros_like(hidden_states).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(num_frames, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.view(num_frames, -1)

        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = (
            codevectors_per_group.view(num_frames, self.num_groups, self.num_vars, -1)
            .sum(-2)
            .view(batch_size, sequence_length, -1)
        )
        return codevectors, perplexity
[W python_anomaly_mode.cpp:104] Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
File "train_Mel2Vec.py", line 352, in <module>
train(training_args)
File "train_Mel2Vec.py", line 318, in train
train_sequence(model, train_dataloader, eval_dataloader, optimizer, scheduler, args, device, vocab, writer, is_pretrain)
File "train_Mel2Vec.py", line 174, in train_sequence
model(input_mels.to(device), mel_lengths.to(device))
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/leecho/xi-stt/conformer/conformer/Mel2Vec.py", line 507, in forward
Quantized_features, perplexity = self.VecQuantization(extract_features, time_mask)
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/leecho/xi-stt/conformer/conformer/Mel2Vec.py", line 57, in forward
hidden_states = self.weight_proj(hidden_states)
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
output = input.matmul(weight.t())
(function _print_stack)
Traceback (most recent call last):
File "train_Mel2Vec.py", line 352, in <module>
train(training_args)
File "train_Mel2Vec.py", line 318, in train
train_sequence(model, train_dataloader, eval_dataloader, optimizer, scheduler, args, device, vocab, writer, is_pretrain)
File "train_Mel2Vec.py", line 178, in train_sequence
loss.backward()
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/leecho/.conda/envs/xi-tts/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1808, 768]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Hi, I’m trying to re-implement the Hugging Face Transformers wav2vec2 pretraining code.
But I get an in-place operation error from the Linear layer in GumbelVectorQuantizer.
I copied the quantizer from the Wav2Vec2ForPreTraining documentation almost verbatim; the only change is how the codevectors are initialized (torch.FloatTensor -> torch.randn).
The error occurs when using gumbel_softmax (self.is_pretrain=True).
Quantized_features, perplexity = self.VecQuantization(extract_features.clone(), time_mask)
As a workaround, I pass extract_features.clone() as the quantizer’s input (the line above), but is that the correct solution?
Additionally, when I pretrain the model this way, the diversity loss gets stuck near 0.06 and doesn’t move.
def _compute_contrastive_loss(
    self,
    predicted_features: Tensor,
    quantized_features: Tensor,
    mask_indices_list: List[List],
):
    """Return the summed contrastive loss over the batch.

    For every batch item, ``K = max(int(len(mask_indices) * self.distractors_rate), 1)``
    time steps are drawn with ``random.sample`` from that item's masked indices. The
    first drawn step is the positive: the prediction at that step is compared, by
    cosine similarity divided by ``self.k_temperature``, against the quantized
    features at *all* drawn steps of *all* batch items. The loss for the item is the
    negative log-softmax at the position of its own quantized feature.

    Args:
        predicted_features: tensor of shape ``(batch, seq, dim)``.
        quantized_features: tensor indexed with the same ``(batch, time)`` positions
            as ``predicted_features``; its last dimension must match.
        mask_indices_list: one list of masked time indices per batch item.

    Returns:
        A 0-dim tensor: the sum of the per-item losses.

    Raises:
        ValueError: from ``random.sample`` when an item has fewer masked indices
            than the number of candidates requested (including an empty list).
    """
    # NOTE(review): the original built dense boolean masks under torch.no_grad();
    # this assumes the no_grad scope ended before the features were gathered, so
    # gradients flow through the loss. Index bookkeeping below is plain Python and
    # needs no no_grad guard. Confirm against the original indentation.
    batch_size, _seq_len, _encoder_dim = predicted_features.shape

    anchor_times = []      # time step of the positive, one per batch item
    candidate_rows = []    # batch index of every candidate
    candidate_times = []   # time step of every candidate
    targets = []           # column of each item's positive among all candidates
    offset = 0
    for i in range(batch_size):
        mask_indices = mask_indices_list[i]
        num_candidates = max(int(len(mask_indices) * self.distractors_rate), 1)  # number of distractors
        candidates = random.sample(mask_indices, num_candidates)
        anchor = candidates[0]

        # Candidates are laid out per item in ascending time order, so the
        # positive's column is its rank among this item's candidates.
        ordered = sorted(candidates)
        anchor_times.append(anchor)
        candidate_rows.extend([i] * num_candidates)
        candidate_times.extend(ordered)
        targets.append(offset + ordered.index(anchor))
        offset += num_candidates

    pred_device = predicted_features.device
    quant_device = quantized_features.device

    positive = predicted_features[
        torch.arange(batch_size, device=pred_device),
        torch.tensor(anchor_times, dtype=torch.long, device=pred_device),
    ]
    negatives = quantized_features[
        torch.tensor(candidate_rows, dtype=torch.long, device=quant_device),
        torch.tensor(candidate_times, dtype=torch.long, device=quant_device),
    ]

    # (batch, 1, dim) against (candidates, dim) -> (batch, candidates)
    logits = F.cosine_similarity(positive.unsqueeze(1), negatives, dim=-1, eps=1e-6) / self.k_temperature

    # NOTE: candidates whose quantized vector is identical to a positive are not
    # masked out of the softmax.

    # cross_entropy == -log_softmax at the target column; unlike -log(softmax(x))
    # it does not overflow to inf when the target logit is far below the maximum.
    target = torch.tensor(targets, dtype=torch.long, device=logits.device)
    loss = F.cross_entropy(logits, target, reduction="sum")
    return loss
In fact, I suspect this error may also be caused by the messy contrastive loss code I wrote.