Would it be possible to implement an iterable dataset with streaming and fast resume (no need to skip batches)?

Training an LLM on a big dataset can take a considerable amount of time, depending on the project. Various kinds of interruptions may occur during training, which then requires resuming from a checkpoint.

When using an iterable dataset with streaming (which I suspect most people do), on resume Hugging Face runs over the batches (skips them) until it reaches the point where training should continue. Running over all those batches can take a lot of time.
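
To illustrate, a resume that works by skipping looks roughly like this (num_examples_already_seen is just a placeholder for the point where training was interrupted):

from datasets import load_dataset

num_examples_already_seen = 32_000  # hypothetical: where training was interrupted

iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
for i, example in enumerate(iterable_dataset):
    if i < num_examples_already_seen:
        continue  # every skipped example is still downloaded and decoded
    ...  # training continues from here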

I wonder whether it would be possible to implement an iterable dataset with fast resume, at least in the case where the training data consists of text files.

Considering the code below, one could jump directly to the appropriate place in the file (where reading stopped) with file.seek, instead of iterating from the beginning until reaching that point.

from typing import Any, Dict


class DatasetWithFastResume:

    def __init__(self, file_path: str):
        self._file_path = file_path
        self._f = open(file_path, "r")
        self._line_number = 0

    def __iter__(self) -> "DatasetWithFastResume":
        return self

    def __next__(self) -> str:
        if self._f.closed:
            raise StopIteration
        line = self._f.readline()
        if not line:  # an empty string means end of file
            self._f.close()
            raise StopIteration
        self._line_number += 1
        return line.rstrip("\n")  # exclude the newline character

    def get_read_position(self) -> int:
        """
        Returns an integer giving the file object's current position in the file,
        represented as the number of bytes from the beginning of the file.
        """
        return self._f.tell()

    def set_read_position(self, position: int):
        self._f.seek(position, 0)

    def get_state(self) -> Dict[str, Any]:
        return {
            "position": self.get_read_position(),
            "line_number": self._line_number,
        }

    def load_state(self, state: Dict[str, Any]):
        self.set_read_position(state["position"])
        self._line_number = state["line_number"]

In the most general case, what is suggested may not guarantee that a resumed training run is equivalent to one that was never interrupted. However, many applications don't need a 100% exact resume, so I think this should be an available option: it is not worth the time spent waiting for the DataLoader to skip batches.
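
For example, checkpointing and resuming with this class could look like this (train.txt and dataset_state.json are just placeholder paths):

import json

dataset = DatasetWithFastResume("train.txt")

# ... consume some lines during training, then checkpoint the read position ...
with open("dataset_state.json", "w") as f:
    json.dump(dataset.get_state(), f)

# after an interrupt: reopen the file and seek directly to the saved position
dataset = DatasetWithFastResume("train.txt")
with open("dataset_state.json") as f:
    dataset.load_state(json.load(f))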

Hi ! This is not implemented yet but would be awesome to have indeed.

Here are some related discussions: Save and resume the state of a DataLoader · Issue #5454 · huggingface/datasets · GitHub

I recently opened a draft PR that implements state_dict() for IterableDataset and enables resuming: [Resumable IterableDataset] Add IterableDataset state_dict by lhoestq · Pull Request #6658 · huggingface/datasets · GitHub

It's WIP/untested and relies on skipping shards and batches (no .seek()), but it should be a good starting point.


Update: the iterable dataset streaming + resume feature has been released

docs: Stream
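
The dataset itself can also be checkpointed directly, without a dataloader; a minimal sketch (the break point is arbitrary):

from datasets import load_dataset

iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")

# iterate for a while, then checkpoint the dataset itself
for i, example in enumerate(iterable_dataset):
    if i == 100:
        break
state_dict = iterable_dataset.state_dict()

# resume iteration from (roughly) where it stopped
iterable_dataset.load_state_dict(state_dict)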

example with the StatefulDataLoader from torchdata:

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)

# checkpoint
state_dict = dataloader.state_dict()  # uses iterable_dataset.state_dict() under the hood

# resume from checkpoint
dataloader.load_state_dict(state_dict)  # uses iterable_dataset.load_state_dict() under the hood
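
The state is a plain dictionary, so it can be saved to disk together with the rest of the training checkpoint, for example (the path is a placeholder):

import torch

torch.save(dataloader.state_dict(), "dataloader_state.pt")
# after a restart: weights_only=False since the state is a plain Python dict, not tensors
dataloader.load_state_dict(torch.load("dataloader_state.pt", weights_only=False))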