Using IterableDataset, Trainer never calls `set_epoch` to increment epoch

I am using a Hugging Face `datasets` IterableDataset, which has a `set_epoch` method, together with the standard `Trainer` class. However, during training the dataset's `_epoch` attribute is never changed (https://github.com/huggingface/datasets/blob/0cc77d7f45c73698c31eab4f8cfff901044d0020/src/datasets/iterable_dataset.py#L1829).
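For context, this is the behavior I'd expect `set_epoch` to trigger; a minimal sketch with toy data (the data and shard/buffer sizes are illustrative, not from my actual setup):

```python
from datasets import Dataset

# Toy dataset converted to an IterableDataset, then shuffled with a fixed seed.
ds = Dataset.from_dict({"x": list(range(10))}).to_iterable_dataset(num_shards=2)
ds = ds.shuffle(seed=42, buffer_size=10)

print(list(ds))  # order for epoch 0 (effective shuffle seed = seed + epoch = 42)

ds.set_epoch(1)  # bumps the internal _epoch attribute
print(list(ds))  # different order (with high probability), effective seed = 43
```

If `Trainer` never calls `set_epoch`, every epoch reuses the epoch-0 shuffle order.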

The `Trainer` docs say an `IterableDataset` should “have a set_epoch() method that internally sets the seed of the RNGs used”. I'm not sure how that helps if `Trainer` never calls it internally at the start of each epoch.
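As a stopgap, one could forward the epoch from a callback; a minimal sketch (the `SetEpochCallback` name is mine, not part of the library):

```python
from transformers import TrainerCallback

class SetEpochCallback(TrainerCallback):
    """Forwards the current epoch to the dataset's set_epoch method."""

    def __init__(self, dataset):
        self.dataset = dataset

    def on_epoch_begin(self, args, state, control, **kwargs):
        # state.epoch is a float (0.0, 1.0, ...); truncate to the epoch index.
        self.dataset.set_epoch(int(state.epoch))

# Registering it on an existing Trainer instance:
# trainer.add_callback(SetEpochCallback(train_dataset))
```

This only covers the single-process case; under `IterableDatasetShard` the call would still need to reach the wrapped dataset.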

Should there be another option in `IterableDatasetShard` to resolve this? https://github.com/huggingface/transformers/blob/63864e057fd4ecbf54c77599702873f7be871e65/src/transformers/trainer_pt_utils.py#L809
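For reference, something like the following per-epoch hook in the training loop would be enough to propagate the epoch; this is only an illustrative sketch of the missing plumbing, not actual `transformers` code:

```python
def begin_epoch(train_dataloader, epoch: int) -> None:
    """Hypothetical per-epoch hook: forward the epoch to the dataset
    (and through it to any set_epoch-aware wrapper) if supported."""
    dataset = getattr(train_dataloader, "dataset", None)
    if dataset is not None and hasattr(dataset, "set_epoch"):
        dataset.set_epoch(epoch)
```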

Transformers issue: Using IterableDataset, Trainer never calls `set_epoch` to increment and resets the epochs to 0 at the beginning of each epoch (huggingface/transformers#26541)

I also think it's the `Trainer`'s job to call `set_epoch`, so let's see what the Transformers folks think about this.