Hi guys,
first the question: if I use a custom batch sampler in the DataLoader of my HF Trainer, how can I make sure the correct state of the random number generator is restored when resuming from a checkpoint? The problem is that I need to set `ignore_data_skip=True` in the `TrainingArguments`, simply because skipping the already-seen batches takes far too long when resuming training. From what I understand of the Trainer source code, it can use `get_state()` and `set_state()` of `torch.Generator` to jump to the correct RNG state immediately, but that does not seem to work out of the box in my case.
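(For reference, this is the `get_state()`/`set_state()` round trip I mean; a standalone toy example, not Trainer code:)

```python
import torch

g = torch.Generator()
g.manual_seed(42)
state = g.get_state()                     # snapshot of the RNG state (a ByteTensor)
first = torch.randperm(10, generator=g)   # consumes randomness

g2 = torch.Generator()
g2.set_state(state)                       # rewind a fresh generator to the snapshot
assert torch.equal(first, torch.randperm(10, generator=g2))
```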
Background:
Some papers use a round-robin batch sampling strategy to sample batches from a collection of datasets instead of simply concatenating the datasets.
While they report performance improvements with this technique, it also has practical advantages, e.g. features from different datasets don't need to be padded to the same shape just to stack the tensors.
My approach so far looks like this:
```python
import logging
import math
from itertools import cycle
from typing import Iterable, List, Sequence, Tuple

import torch
from torch.utils.data import BatchSampler, SubsetRandomSampler

logger = logging.getLogger(__name__)


class RoundRobinSampler:

    def __init__(self, samplers: Sequence[Iterable], reinit: bool = False):
        """
        A sampler that cycles through the list of given samplers and 'forwards'
        the next() call to each sampler in turn ("round robin").

        Args:
            samplers (Sequence[Iterable]): the list of samplers that will be cycled
            reinit (bool): when one of the samplers is exhausted, should it be
                re-initialized or not?
                (!) if yes, this results in an infinite iterator, and epochs will not end
        """
        self.samplers = samplers
        self.reinit = reinit

    def __iter__(self):
        iterators = [iter(sampler) for sampler in self.samplers]
        for i in cycle(range(len(iterators))):
            it = iterators[i]
            try:
                yield next(it)
            except StopIteration:
                # the current iterator is exhausted
                if not self.reinit:
                    break
                # re-initialize the iterator and continue
                it = iter(self.samplers[i])
                iterators[i] = it
                yield next(it)
```
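Just to illustrate the behaviour, with two plain lists standing in for the samplers:

```python
sampler = RoundRobinSampler([[0, 1, 2], [10, 11, 12, 13]], reinit=False)
print(list(sampler))  # [0, 10, 1, 11, 2, 12] -- stops once the shorter sampler is exhausted
```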
```python
def get_subset(length: int, i: int, k: int, offset: int = 0) -> Tuple[int, int]:
    assert i < k
    s = math.ceil(length / k)  # size of one split
    start = i * s
    end = min((i + 1) * s, length)
    return offset + start, offset + end
```
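For example, splitting a dataset of length 10 across 3 ranks gives:

```python
[get_subset(10, rank, 3) for rank in range(3)]  # [(0, 4), (4, 8), (8, 10)]
```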
```python
class DistributedRoundRobinBatchSampler:
    """
    Creates one sampler for every dataset; each sampler only samples from this
    rank's split of the indices.

    Assume two datasets of size 10 and 3 gpus:

        dataset 1                    dataset 2
        | 0 1 2 3 4 5 6 7 8 9 | 10 11 12 13 14 15 16 17 18 19 |
        |        |       |    |            |           |      |
          gpu 1    gpu 2  gpu 3    gpu 1       gpu 2    gpu 3

    For this particular gpu (= rank), get a list of ranges representing dataset
    indices; the indices are shuffled per gpu.

    TODO gpu 3 will get much less data than the others in this scenario.
    """

    def __init__(
        self,
        lengths: List[int],
        batch_size: int,
        rank: int,
        num_replicas: int,
        drop_last: bool = False,
        seed: int = 0,
        shuffle: bool = True,
        reinit: bool = False
    ):
        # offset of each dataset inside the concatenated index space
        offsets = [sum(lengths[:i]) for i in range(len(lengths))]
        # this rank's (start, end) index range for every dataset
        self.ranges = [
            get_subset(length, rank, num_replicas, offset)
            for offset, length in zip(offsets, lengths)
        ]
        self.seed = seed
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0
        self.reinit = reinit
        self.batch_size = batch_size
        self.batch_start = 0
        logger.info(f"initialized sampler {rank=} {num_replicas=} {seed=} {self.ranges} {batch_size=}")

    def __iter__(self):
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        batch_samplers = [
            BatchSampler(
                SubsetRandomSampler(range(start, end), generator=g) if self.shuffle else range(start, end),
                self.batch_size, self.drop_last
            ) for (start, end) in self.ranges
        ]
        sampler = RoundRobinSampler(batch_samplers, reinit=self.reinit)
        return iter(sampler)

    def set_epoch(self, epoch: int):
        self.epoch = epoch
```
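One idea I've been playing with, but haven't tested: keep the generator on the sampler instance and expose its state, so it can be saved with a checkpoint and restored afterwards. The class and the `state_dict`/`load_state_dict`/`_resume_state` names below are just my own convention, nothing the Trainer knows about:

```python
class StatefulRoundRobinBatchSampler(DistributedRoundRobinBatchSampler):

    def __iter__(self):
        # keep a handle to the generator instead of using a local variable
        self.generator = torch.Generator()
        if getattr(self, "_resume_state", None) is not None:
            # continue the random stream exactly where the checkpointed state left off
            self.generator.set_state(self._resume_state)
            self._resume_state = None
        else:
            self.generator.manual_seed(self.seed + self.epoch)
        batch_samplers = [
            BatchSampler(
                SubsetRandomSampler(range(start, end), generator=self.generator)
                if self.shuffle else range(start, end),
                self.batch_size, self.drop_last
            ) for (start, end) in self.ranges
        ]
        return iter(RoundRobinSampler(batch_samplers, reinit=self.reinit))

    def state_dict(self):
        # torch.Generator.get_state() returns a ByteTensor that torch.save() can handle
        gen_state = self.generator.get_state() if hasattr(self, "generator") else None
        return {"epoch": self.epoch, "generator": gen_state}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self._resume_state = state["generator"]
```

Even then this only reproduces the same shuffling order for the epoch; it does not skip the batches that were already consumed before the checkpoint, which is exactly what `ignore_data_skip=True` throws away, so I would probably still need to track something like `batch_start`.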
In my Trainer subclass, I override `get_train_dataloader()`:
```python
def get_train_dataloader(self) -> DataLoader:
    dataset = self.train_dataset
    if self.args.round_robin and isinstance(dataset, ConcatDataset):
        sizes = [len(ds) for ds in dataset.datasets]
    else:
        sizes = [len(dataset)]
    loader = DataLoader(
        dataset,
        batch_sampler=DistributedRoundRobinBatchSampler(
            lengths=sizes,
            batch_size=self.args.train_batch_size,
            drop_last=self.args.dataloader_drop_last,
            rank=self.args.process_index,
            num_replicas=self.args.world_size,
            seed=self.args.seed,
            reinit=True,
        ),
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        collate_fn=self.data_collator
    )
    return loader
```
As you can see, I need a batch sampler to make sure the indices are sampled from the separate datasets in turn, and the `torch.Generator()` is created inside my `DistributedRoundRobinBatchSampler` class, so the Trainer never sees it and cannot restore its state when resuming.
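If the sampler exposed its state like in the sketch above, I suppose it could be saved and restored next to each checkpoint from a `TrainerCallback`, roughly like this (untested; I'm assuming `on_save`/`on_train_begin` receive `train_dataloader` in their kwargs, and `batch_sampler.pt` is just a file name I made up):

```python
import os
import torch
from transformers import TrainerCallback


class SamplerStateCallback(TrainerCallback):
    """Persist the custom batch sampler's RNG state next to each checkpoint."""

    def on_save(self, args, state, control, train_dataloader=None, **kwargs):
        if train_dataloader is not None:
            path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}", "batch_sampler.pt")
            torch.save(train_dataloader.batch_sampler.state_dict(), path)

    def on_train_begin(self, args, state, control, train_dataloader=None, **kwargs):
        # when resuming, load the state saved alongside the checkpoint (path handling kept minimal)
        resume_path = getattr(args, "resume_from_checkpoint", None)
        if train_dataloader is not None and isinstance(resume_path, str):
            sampler_state = os.path.join(resume_path, "batch_sampler.pt")
            if os.path.exists(sampler_state):
                train_dataloader.batch_sampler.load_state_dict(torch.load(sampler_state))
```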
I'd be happy about any recommendations regarding the problem, or general tips on the code or the overall approach!
Best, David