I’m trying to train an ASR model using Transformers, Datasets and lightning. I want to use lightning as it will make it easier to train on my custom dataset.
I have posted the code I’m using below as well as some loss curves and WER curves.
I’ve trained on the librispeech dataset, only including utterances of 5s or less, using a batch size of 16, training for 200 epochs and a learning rate of 0.001. The loss curve is pretty much constant for the whole duration. Can anybody tell me what’s going on?
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from torchmetrics import WordErrorRate
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from datasets import load_dataset, Audio
from transformers import AutoProcessor, AutoConfig, AutoModelForCTC
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--nepochs", type=int, default=200, help="number of epochs")
parser.add_argument("--batchsize", type=int, default=16, help="batch size")
parser.add_argument("--devices", type=int, nargs='+', default=[0], help="batch size")
args = parser.parse_args()
###############################################################################################################
###############################################################################################################
### DATA
###############################################################################################################
###############################################################################################################
def createASRAudioDataset(batchsize, nworkers, processor, max_duration=5):
def prepareDataset(dataset):
dataset = dataset.select_columns(['audio', 'text'])
dataset = dataset.filter(lambda e: (len(e['audio']['array']) / e['audio']['sampling_rate']) <= max_duration)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.map(lambda e: {"text": e['text'].upper()})
dataset = dataset.map(lambda e: processor(e['audio']['array'], sampling_rate=16000, text=e['text']), num_proc=4)
dataset = dataset.remove_columns(['audio', 'text'])
return dataset
def collate(batch):
input_features = [{"input_values": b["input_values"][0]} for b in batch]
label_features = [{"input_ids": b["labels"]} for b in batch]
batch = processor.pad(input_features, padding="longest", return_tensors="pt", return_attention_mask=True)
batch_labels = processor.pad(labels=label_features, padding="longest", return_tensors="pt")
batch_labels = batch_labels["input_ids"].masked_fill(batch_labels.attention_mask.ne(1), -100)
batch["labels"] = batch_labels
return batch
datasets = load_dataset("librispeech_asr")
trainset = DataLoader(prepareDataset(datasets['train.clean.100']), batch_size=batchsize, collate_fn=collate, shuffle=True, num_workers=nworkers)
valset = DataLoader(prepareDataset(datasets['validation.clean']), batch_size=batchsize, collate_fn=collate, shuffle=False, num_workers=nworkers)
return trainset, valset
###############################################################################################################
###############################################################################################################
### Model
###############################################################################################################
###############################################################################################################
def createModel(processor):
config = AutoConfig.from_pretrained("facebook/wav2vec2-base")
config.vocab_size = len(processor.tokenizer)
config.ctc_loss_reduction = "mean"
config.pad_token_id = processor.tokenizer.pad_token_id
model = AutoModelForCTC.from_config(config)
model.gradient_checkpointing_enable()
return model
###############################################################################################################
###############################################################################################################
### Training
###############################################################################################################
###############################################################################################################
class LitModule(pl.LightningModule):
def __init__(self, processor):
super().__init__()
self.net = createModel(processor)
self.processor = processor
self.wer_train = WordErrorRate()
self.wer_eval = WordErrorRate()
def _step(self, batch, batch_index, training):
keyword = 'train' if training else 'val'
wer = self.wer_train if training else self.wer_eval
outputs = self.net(**batch)
loss = outputs.loss
logits = outputs.logits
probs = torch.softmax(logits, dim=-1)
est = torch.argmax(probs, -1)
ref = batch['labels']
ref[ref < 0] = self.processor.tokenizer.pad_token_id
est_str = self.processor.batch_decode(est)
ref_str = self.processor.batch_decode(ref)
wer(est_str, ref_str)
self.log("loss/" + keyword, loss, sync_dist=True, on_step=True, on_epoch=not training, prog_bar=True)
self.log("wer/" + keyword, wer, sync_dist=True, on_step=True, on_epoch=not training, prog_bar=True)
return loss
def training_step(self, batch, batch_index):
return self._step(batch, batch_index, True)
def validation_step(self, batch, batch_index):
return self._step(batch, batch_index, False)
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=args.lr, fused=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=args.nepochs, eta_min=1e-6)
return {
'optimizer': optimizer,
'lr_scheduler': {'scheduler': scheduler, 'interval': "epoch", "frequency": 1}
}
if __name__ == '__main__':
# Processor
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
# Data
trainset, valset = createASRAudioDataset(args.batchsize, 1, processor, max_duration=5)
# Model
litModel = LitModule(processor)
# Callbacks
callbacks = [
ModelCheckpoint(monitor='loss/val_epoch',
filename='epoch={epoch}-val_loss={loss/val_epoch:.4f}',
auto_insert_metric_name=False,
save_weights_only=True),
LearningRateMonitor(logging_interval='step')
]
# Logger
logger = TensorBoardLogger(save_dir="../runs_asr")
# Train
trainer = pl.Trainer(max_epochs=args.nepochs,
accelerator='gpu',
devices=args.devices,
gradient_clip_val=1.0,
callbacks=callbacks,
logger=logger,
sync_batchnorm=True,
)
trainer.fit(model=litModel, train_dataloaders=trainset, val_dataloaders=valset)