Regarding this problem, we had an impromptu debugging session in Slack with @PereLluis13 and @maxidl, and realized that the delay is caused by grouping samples by length: in that case the dataset is accessed sequentially to determine every sample's length.
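For context, when no lengths are supplied, the length-grouped sampler has to derive them by visiting the dataset one example at a time, conceptually something like the sketch below (a paraphrase of the behaviour described above, not the actual library source), which for an audio dataset means materializing every sample just to measure it:

```python
# Conceptual sketch of what happens when lengths are NOT pre-computed:
# every example is fetched one by one just to read its length.
lengths = [len(example["input_values"]) for example in common_voice_train]
```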
As a temporary solution, I implemented one of the ideas we discussed with @adilism: pre-compute the lengths of the samples and subclass Trainer to use them, if available:
```python
# Pre-compute sample lengths
def input_lengths(example):
    example["length"] = len(example["input_values"])
    return example

# Adjust for your system
num_proc = 16

common_voice_train = common_voice_train.map(input_lengths, num_proc=num_proc)
```
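If you want to verify the new column before training (an optional sanity check, not part of the original recipe), reading it back is cheap because only that column is materialized:

```python
# Check that the "length" column matches the audio arrays
print(common_voice_train[0]["length"], len(common_voice_train[0]["input_values"]))

# Reading the whole column at once should be fast, since no audio data is touched
all_lengths = common_voice_train["length"]
```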
```python
## Use subclassed Trainer class to support pre-computed lengths
from typing import Optional
import collections

import torch
from torch.utils.data import DataLoader
from transformers import Trainer
from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler


class GroupedLengthsTrainer(Trainer):
    # length_field_name should possibly be part of TrainingArguments instead
    def __init__(self, length_field_name=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.length_field_name = length_field_name

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
            self.train_dataset, collections.abc.Sized
        ):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            # Read the pre-computed lengths column, if one was provided
            lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.train_dataset,
                    self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                )
        else:
            return super()._get_train_sampler()
```
```python
# Build trainer indicating the name of the field that contains the lengths
trainer = GroupedLengthsTrainer(
    length_field_name="length",
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)
```
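Note that the custom sampler only kicks in when length grouping is enabled in the training arguments, so something along these lines is assumed to be in place (the values other than group_by_length are illustrative placeholders, not the original configuration):

```python
from transformers import TrainingArguments

# group_by_length=True is what routes execution into the length-grouped sampler above
training_args = TrainingArguments(
    output_dir="./wav2vec2-common-voice",  # placeholder path
    group_by_length=True,
    per_device_train_batch_size=16,
    num_train_epochs=30,
)
```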
This is ugly because we are overriding a private method. In addition, a better place to indicate the field name to use for sorting would probably be TrainingArguments, but then we'd have to subclass or wrap that one too. Still, it gets the job done until the issue referenced above is discussed and resolved.