Spanish ASR: Fine-Tuning Wav2Vec2

Regarding this problem, we had an impromptu debugging session in Slack with @PereLluis13 and @maxidl, and realized that the delay is caused by grouping samples by length: in that case the dataset is accessed sequentially to read every sample's length.
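
For context, when no lengths are supplied, the stock LengthGroupedSampler derives them itself by walking the whole dataset, roughly like the sketch below (a paraphrase of the idea, not the actual transformers source); with audio data that means touching every example just to measure it:

# Rough paraphrase of the slow path: one sequential pass over the dataset,
# only to read each example's input length.
lengths = [len(example["input_values"]) for example in common_voice_train]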

As a temporary solution, I implemented one of the ideas we discussed with @adilism: precompute the sample lengths and subclass Trainer to use them when available:

# Pre-compute sample lengths
def input_lengths(example):
    example["length"] = len(example["input_values"])
    return example

# Adjust for your system
num_proc = 16
common_voice_train = common_voice_train.map(input_lengths, num_proc=num_proc)

## Subclass Trainer to support pre-computed lengths

import collections
from typing import Optional

import torch
from transformers import Trainer
from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler

class GroupedLengthsTrainer(Trainer):
    # length_field_name should possibly be part of TrainingArguments instead
    def __init__(self, length_field_name=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.length_field_name = length_field_name
    
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
            self.train_dataset, collections.abc.Sized
        ):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.train_dataset,
                    self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                )
        else:
            return super()._get_train_sampler()

# Build trainer indicating the name of the field that contains the lengths
trainer = GroupedLengthsTrainer(
    length_field_name="length",
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)
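
For reference, the training_args passed above still need group_by_length=True for this code path to kick in; a minimal sketch of the relevant arguments (placeholder values, not my actual configuration):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./wav2vec2-xlsr-es",    # placeholder path
    group_by_length=True,               # without this, the default random sampler is used
    per_device_train_batch_size=16,     # placeholder batch size
)

After that, trainer.train() runs as usual and batching uses the pre-computed lengths.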

This is ugly because we override a private method. In addition, a better place to indicate the field name used for sorting would probably be TrainingArguments, but then we'd have to subclass or wrap that one too. Still, it gets the job done until the issue referenced above is discussed and resolved.
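
If we went the TrainingArguments route instead, a subclassed dataclass could carry the field name; just a sketch of the idea, untested:

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments

@dataclass
class GroupedLengthsTrainingArguments(TrainingArguments):
    # Hypothetical extra argument: dataset column holding the pre-computed lengths
    length_field_name: Optional[str] = field(
        default=None,
        metadata={"help": "Dataset column with pre-computed input lengths for group_by_length."},
    )

The Trainer subclass would then read self.args.length_field_name instead of taking its own constructor argument.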
