Accelerator .prepare() replaces custom DataLoader Sampler

I want to use a DataLoader with a custom sampler (RASampler) taken from torchvision's reference scripts: vision/references/classification/sampler.py at main · pytorch/vision · GitHub
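To make the setup reproducible without the torchvision reference scripts, here is a minimal sketch of a DataLoader built around a custom sampler. `RepeatSampler` is a hypothetical, simplified single-process stand-in for RASampler: it only repeats each index, with no shuffling and no sharding across processes. The toy TensorDataset is also an assumption.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class RepeatSampler(Sampler):
    """Hypothetical, simplified stand-in for RASampler: yields every index
    `repetitions` times in a row (single process, no shuffling, no sharding)."""

    def __init__(self, data_source, repetitions=3):
        self.data_source = data_source
        self.repetitions = repetitions

    def __iter__(self):
        for idx in range(len(self.data_source)):
            for _ in range(self.repetitions):
                yield idx

    def __len__(self):
        return len(self.data_source) * self.repetitions


dataset = TensorDataset(torch.arange(4))
# The custom sampler is passed via `sampler=`, so DataLoader wraps it in a BatchSampler.
dataloader = DataLoader(dataset, batch_size=3, sampler=RepeatSampler(dataset))

print(type(dataloader.sampler).__name__)
for (batch,) in dataloader:
    print(batch.tolist())
```

Before any wrapping, `dataloader.sampler` is the custom sampler itself, and each batch holds one index repeated three times.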

When I run:

print(dataloader, dataloader.sampler)
dataloader = accelerator.prepare(dataloader)
print(dataloader, dataloader.sampler)

I get the following output:

<torch.utils.data.dataloader.DataLoader object at 0x7f987fbc4e50> <utils.imagenet_dataloader.RASampler object at 0x7f987fbc4e20>
<accelerate.data_loader.DataLoaderShard object at 0x7f98518b6da0> <torch.utils.data.sampler.SequentialSampler object at 0x7f98518b6b60>

This means that my RASampler got turned into a SequentialSampler.
Is this normal behaviour? Since it seems I can't manually restore my sampler afterwards, this is quite a problem.
Could you tell me how to solve this problem?
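One possible explanation, sketched with plain PyTorch: when a DataLoader is constructed with `batch_sampler=` instead of `sampler=`, PyTorch still fills in `.sampler` with a default SequentialSampler, but iteration actually goes through the batch sampler. The sketch assumes that Accelerate's `prepare()` rebuilds the loader in a similar way (the DataLoaderShard type in the output is consistent with a rebuilt loader, but this is not confirmed here). `ReverseSampler` is a hypothetical custom sampler used only for illustration.

```python
from torch.utils.data import DataLoader, Sampler


class ReverseSampler(Sampler):
    """Hypothetical custom sampler: visits indices in reverse order."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


dataset = list(range(6))
original = DataLoader(dataset, batch_size=2, sampler=ReverseSampler(dataset))

# Assumption: a wrapper rebuilds the loader by reusing the original batch_sampler
# and passing it via `batch_sampler=` instead of `sampler=`.
rebuilt = DataLoader(dataset, batch_sampler=original.batch_sampler)

print(type(rebuilt.sampler).__name__)                # placeholder, not used for iteration
print(type(rebuilt.batch_sampler.sampler).__name__)  # the custom sampler that is actually used
print([b.tolist() for b in rebuilt])                 # batches still follow ReverseSampler's order
```

If that assumption holds, the printed SequentialSampler would only be a placeholder, and the RASampler might still be reachable through `dataloader.batch_sampler.sampler`, or one level deeper (for example behind an Accelerate wrapper such as BatchSamplerShard) in multi-process runs.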
