Implementing Round Robin Batch Sampling with the HF Trainer

Hi guys,

First the question: if I use a custom batch sampler in the DataLoader of my HF Trainer, how can I make sure the correct state of the random number generator is restored when resuming from a checkpoint? The problem is that I need to set ignore_data_skip=True in the TrainingArguments, simply because the data skip takes way too long when resuming training. From what I understand of the Trainer source code, it can use get_state() and set_state() of torch.Generator to jump to the correct state immediately, but that does not seem to work out of the box in my case.
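
For reference, this is the bare mechanism I mean: torch.Generator can snapshot and restore its state, so in principle the shuffling could be resumed exactly without replaying any batches. (This is just a minimal, self-contained sketch of get_state()/set_state(), not the Trainer's actual resume logic.)

import torch

g = torch.Generator()
g.manual_seed(42)
_ = torch.randperm(10, generator=g)  # consume some randomness, as training would

state = g.get_state()  # ByteTensor snapshot that could be stored next to a checkpoint

g_resumed = torch.Generator()
g_resumed.set_state(state)  # restore the exact RNG state on resume

# both generators now continue identically, no data skip needed
assert torch.equal(torch.randperm(10, generator=g), torch.randperm(10, generator=g_resumed))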

Background:
Some papers use a round robin batch sampling strategy to sample batches from a collection of datasets instead of simply concatenating the datasets.
While they report performance improvements from this technique, it also has practical advantages, e.g. features from different datasets don't need to be padded to a common shape before the tensors are stacked.
My approach so far looks like this:

import logging
import math
from itertools import cycle
from typing import Iterable, List, Sequence, Tuple

import torch
from torch.utils.data import BatchSampler, SubsetRandomSampler

logger = logging.getLogger(__name__)


class RoundRobinSampler:

    def __init__(self, samplers: Sequence[Iterable], reinit: bool = False):
        """
        A sampler that cycles through the list of given samplers and 'forwards'
        the next() call to each sampler in turn ("round robin").

        Args:
            samplers (Sequence[Iterable]): the list of samplers that will be cycled
            reinit (bool): when one of the samplers is exhausted, should it be re-
                initialized or not?
                (!) if yes, this will result in an infinite iterator, and epochs will not end
        """
        self.samplers = samplers
        self.reinit = reinit
        
    def __iter__(self):
        iterators = [iter(sampler) for sampler in self.samplers]
        
        for i in cycle(range(len(iterators))):
            it = iterators[i]
            
            try:
                yield next(it)

            except StopIteration:
                # current iterator is apparently exhausted
                if not self.reinit: break

                # re-initialize the iterator
                it = iter(self.samplers[i])
                iterators[i] = it
                yield next(it)
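
# Quick sanity check of the round-robin behaviour, with toy lists standing in for
# real samplers (illustration only, not part of the training code):
#   list(RoundRobinSampler([[1, 2], ["a", "b"]]))  ->  [1, "a", 2, "b"]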

def get_subset(length: int, i: int, k: int, offset: int = 0) -> Tuple[int, int]:
    """Return the (start, end) index range of the i-th of k splits of a dataset
    with the given length, shifted by offset."""
    assert i < k
    s = math.ceil(length / k)  # size of one split
    start = i * s
    end = min((i + 1) * s, length)
    return offset + start, offset + end
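
# Worked example (length 10 split across 3 ranks), matching the diagram in the
# class below:
#   get_subset(10, 0, 3) == (0, 4)
#   get_subset(10, 1, 3) == (4, 8)
#   get_subset(10, 2, 3) == (8, 10)
#   get_subset(10, 2, 3, offset=10) == (18, 20)   # rank 2's slice of the second dataset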


class DistributedRoundRobinBatchSampler:
    """
    Creates one batch sampler per dataset; each one only samples from this rank's
    split of the indices. Assume two datasets with 10 examples each and 3 GPUs:

                dataset 1                     dataset 2
    | 0  1  2  3  4  5  6  7  8  9 | 10 11 12 13 14 15 16 17 18 19 |
    |            |          |      |            |           |      |
         gpu 1      gpu 2     gpu 3     gpu 1       gpu 2     gpu 3

    For this particular GPU (= rank), build a list of ranges representing dataset
    indices; the indices within each range are shuffled per GPU.

    TODO: GPU 3 will get much less data than the others in this scenario.
    """

    def __init__(
        self,
        lengths: List[int],
        batch_size: int,
        rank: int,
        num_replicas: int,
        drop_last: bool = False,
        seed: int = 0,
        shuffle: bool = True,
        reinit: bool = False
    ):
        offsets = [sum(lengths[:i]) for i in range(len(lengths))]
        self.ranges = [get_subset(length, rank, num_replicas, offset) for offset, length in zip(offsets, lengths)]
        self.seed = seed
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0
        self.reinit = reinit
        self.batch_size = batch_size
        self.batch_start = 0
        logger.info(f"initialized sampler {rank=} {num_replicas=} {seed=} {self.ranges} {batch_size=}")
        
    def __iter__(self):

        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        batch_samplers = [
            BatchSampler(
                SubsetRandomSampler(range(start, end), generator=g) if self.shuffle else range(start, end),
                self.batch_size, self.drop_last
            ) for (start, end) in self.ranges
        ]

        sampler = RoundRobinSampler(batch_samplers, reinit=self.reinit)
        return iter(sampler)

    def set_epoch(self, epoch: int):
        self.epoch = epoch
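
For illustration, here is roughly the direction I was considering to make the generator state resumable (just a sketch; the state_dict()/load_state_dict() names are placeholders of my own, nothing the Trainer calls automatically, and wiring them into checkpoint saving/loading, e.g. via a TrainerCallback, is exactly the part I'm unsure about). Note that restoring the generator state only reproduces the epoch's shuffle; to truly skip into the middle of an epoch I would presumably also have to track how many batches were already consumed.

class ResumableRoundRobinBatchSampler(DistributedRoundRobinBatchSampler):
    """Hypothetical variant that keeps the generator on self and exposes its state."""

    def __iter__(self):
        # keep the generator as an attribute so its state can be captured later
        self.generator = torch.Generator()
        if getattr(self, "_resume_state", None) is not None:
            self.generator.set_state(self._resume_state)  # continue from a checkpointed state
            self._resume_state = None
        else:
            self.generator.manual_seed(self.seed + self.epoch)

        batch_samplers = [
            BatchSampler(
                SubsetRandomSampler(range(start, end), generator=self.generator)
                if self.shuffle else range(start, end),
                self.batch_size, self.drop_last
            ) for (start, end) in self.ranges
        ]
        return iter(RoundRobinSampler(batch_samplers, reinit=self.reinit))

    def state_dict(self):
        # to be saved alongside a checkpoint (placeholder API, not called by the Trainer)
        return {"epoch": self.epoch, "generator": self.generator.get_state()}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self._resume_state = state["generator"]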

In my Trainer subclass, I override get_train_dataloader():

    def get_train_dataloader(self) -> DataLoader:
        dataset = self.train_dataset
        
        if self.args.round_robin and isinstance(dataset, ConcatDataset):
            sizes = [len(ds) for ds in dataset.datasets]
        else:
            sizes = [len(dataset)]

        loader = DataLoader(
            dataset,
            batch_sampler=DistributedRoundRobinBatchSampler(
                lengths=sizes,
                batch_size=self.args.train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                rank=self.args.process_index,
                num_replicas=self.args.world_size,
                seed=self.args.seed,
                reinit=True,
            ),
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            collate_fn=self.data_collator
        )
        return loader
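
For context, the individual datasets are handed to the trainer inside a ConcatDataset, so the override above can recover their lengths via .datasets, and features with different shapes never end up in the same batch. A toy stand-in for my real data:

import torch
from torch.utils.data import ConcatDataset, TensorDataset

# two datasets with different feature shapes, kept separate inside the ConcatDataset
dataset_a = TensorDataset(torch.randn(100, 8))
dataset_b = TensorDataset(torch.randn(40, 16))
train_dataset = ConcatDataset([dataset_a, dataset_b])

sizes = [len(ds) for ds in train_dataset.datasets]  # -> [100, 40]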

As you can see, I need to use a batch sampler to make sure the indices are sampled from separate datasets in turn, and I create a torch.Generator() inside my DistributedRoundRobinBatchSampler class.

I am happy about any recommendations regarding the problem, as well as general tips on the code or the overall approach!

Best, David