Can't get Wav2Vec to converge

I’m trying to train an ASR model using Transformers, Datasets, and Lightning. I want to use Lightning since it will make it easier to train on my custom dataset.

I have posted the code I’m using below, as well as some loss and WER curves.
I’ve trained on the LibriSpeech dataset, only including utterances of 5 s or less, using a batch size of 16, training for 200 epochs with a learning rate of 0.001. The loss curve is pretty much flat for the whole run. Can anybody tell me what’s going on?

import  argparse
import  torch
from    torch.utils.data    import Dataset, DataLoader
from    torchmetrics        import WordErrorRate
import  lightning.pytorch   as pl
from    lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from    lightning.pytorch.loggers import TensorBoardLogger
from    datasets            import load_dataset, Audio
from    transformers        import AutoProcessor, AutoConfig, AutoModelForCTC

parser = argparse.ArgumentParser()
parser.add_argument("--lr",         type=float, default=0.001,  help="Learning rate")
parser.add_argument("--nepochs",    type=int,   default=200,    help="number of epochs")
parser.add_argument("--batchsize",  type=int,   default=16,     help="batch size")
parser.add_argument("--devices",    type=int,   nargs='+', default=[0],  help="batch size")
args = parser.parse_args()

###############################################################################################################
###############################################################################################################
### DATA
###############################################################################################################
###############################################################################################################

def createASRAudioDataset(batchsize, nworkers, processor, max_duration=5):
    def prepareDataset(dataset):
        dataset = dataset.select_columns(['audio', 'text'])
        dataset = dataset.filter(lambda e: (len(e['audio']['array']) / e['audio']['sampling_rate']) <= max_duration)
        dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
        dataset = dataset.map(lambda e: {"text": e['text'].upper()})
        dataset = dataset.map(lambda e: processor(e['audio']['array'], sampling_rate=16000, text=e['text']), num_proc=4)
        dataset = dataset.remove_columns(['audio', 'text'])
        return dataset

    def collate(batch):
        input_features  = [{"input_values": b["input_values"][0]} for b in batch]
        label_features  = [{"input_ids": b["labels"]} for b in batch]

        batch           = processor.pad(input_features, padding="longest", return_tensors="pt", return_attention_mask=True)
        batch_labels    = processor.pad(labels=label_features, padding="longest", return_tensors="pt")
        batch_labels    = batch_labels["input_ids"].masked_fill(batch_labels.attention_mask.ne(1), -100)   # -100 marks padding positions to be ignored by the CTC loss
        batch["labels"] = batch_labels

        return batch

    datasets = load_dataset("librispeech_asr", "all")  # the "all" config provides the train.clean.100 / validation.clean splits
    trainset = DataLoader(prepareDataset(datasets['train.clean.100']), batch_size=batchsize, collate_fn=collate, shuffle=True, num_workers=nworkers)
    valset   = DataLoader(prepareDataset(datasets['validation.clean']), batch_size=batchsize, collate_fn=collate, shuffle=False, num_workers=nworkers)
    return trainset, valset

###############################################################################################################
###############################################################################################################
### Model
###############################################################################################################
###############################################################################################################

def createModel(processor):
    config = AutoConfig.from_pretrained("facebook/wav2vec2-base")
    config.vocab_size           = len(processor.tokenizer)
    config.ctc_loss_reduction   = "mean"
    config.pad_token_id         = processor.tokenizer.pad_token_id
    model                       = AutoModelForCTC.from_config(config)   # from_config builds a randomly initialized model (no pretrained weights)
    model.gradient_checkpointing_enable()
    return model


###############################################################################################################
###############################################################################################################
### Training
###############################################################################################################
###############################################################################################################

class LitModule(pl.LightningModule):
    def __init__(self, processor):
        super().__init__()
    
        self.net        = createModel(processor)
        self.processor  = processor
        self.wer_train  = WordErrorRate()
        self.wer_eval   = WordErrorRate()

    def _step(self, batch, batch_index, training):
        keyword = 'train' if training else 'val'
        wer     = self.wer_train if training else self.wer_eval

        outputs = self.net(**batch)
    
        loss   = outputs.loss
        logits = outputs.logits
        probs  = torch.softmax(logits, dim=-1)
        est    = torch.argmax(probs, -1)
        ref    = batch['labels']
        ref[ref < 0] = self.processor.tokenizer.pad_token_id

        est_str = self.processor.batch_decode(est)
        ref_str = self.processor.batch_decode(ref)
        wer(est_str, ref_str)

        self.log("loss/" + keyword, loss, sync_dist=True, on_step=True, on_epoch=not training, prog_bar=True)
        self.log("wer/"  + keyword, wer, sync_dist=True, on_step=True, on_epoch=not training, prog_bar=True)
        return loss

    def training_step(self, batch, batch_index):
        return self._step(batch, batch_index, True)
    
    def validation_step(self, batch, batch_index):
        return self._step(batch, batch_index, False)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=args.lr, fused=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.nepochs, eta_min=1e-6)
        return {
            'optimizer': optimizer, 
            'lr_scheduler': {'scheduler': scheduler, 'interval': "epoch", "frequency": 1}
        }
    
if __name__ == '__main__':

    # Processor
    processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")

    # Data
    trainset, valset = createASRAudioDataset(args.batchsize, 1, processor, max_duration=5)

    # Model
    litModel = LitModule(processor)

    # Callbacks
    callbacks = [
        ModelCheckpoint(monitor='loss/val_epoch', 
                        filename='epoch={epoch}-val_loss={loss/val_epoch:.4f}',
                        auto_insert_metric_name=False,
                        save_weights_only=True),
        LearningRateMonitor(logging_interval='step')
    ]

    # Logger
    logger = TensorBoardLogger(save_dir="../runs_asr")
    
    # Train
    trainer = pl.Trainer(max_epochs=args.nepochs, 
                         accelerator='gpu', 
                         devices=args.devices, 
                         gradient_clip_val=1.0,
                         callbacks=callbacks,
                         logger=logger,
                         sync_batchnorm=True,
                         )

    trainer.fit(model=litModel, train_dataloaders=trainset, val_dataloaders=valset)

I used LearningRateFinder and it came back with a suggested learning rate of 6.3095734448019305e-06.
It definitely feels like there is an error somewhere.
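
For context, here is roughly how I invoked the finder (a sketch using Lightning’s Tuner API; written from memory, so treat the exact arguments as an assumption):

from lightning.pytorch.tuner import Tuner

# LR range test, run before trainer.fit()
tuner     = Tuner(trainer)
lr_finder = tuner.lr_find(litModel, train_dataloaders=trainset, val_dataloaders=valset)
print(lr_finder.suggestion())   # -> 6.3095734448019305e-06 in my run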

The thread “Wav2Vec2 WER remains 1.00 and return blank transcriptions” looks related.

Does anybody know what kind of parameters you need to successfully train Wav2Vec? For example (I’ve put a sketch of a typical fine-tuning setup after this list):

  • Dataset size
  • Utterance length (I’m using 5 seconds. Maybe you need much longer. I’m bound by my 12GB 1080Ti)
  • Learning rate
  • Optimizer (I’m using AdamW)
  • Scheduler (I’m using CosineAnnealing)
  • Number of epochs (200)
  • Tokenizer (I’m using facebook/wav2vec2-base configuration, which seems to be a character-based tokenizer. Maybe that’s way too simple)
  • etc
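
For reference, the recipe in the Hugging Face fine-tuning examples starts from the pretrained checkpoint rather than a random init (my createModel above uses from_config, i.e. random weights), freezes the feature encoder, and fine-tunes at around 1e-4 with warmup. A minimal sketch, assuming those commonly cited defaults (I have not verified them on my setup):

import torch
from transformers import Wav2Vec2ForCTC

# Load pretrained weights; only the CTC head (lm_head) is freshly initialized.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
model.freeze_feature_encoder()   # the convolutional feature encoder is usually kept frozen

# Commonly cited fine-tuning learning rate (assumption, not verified here)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)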