Identifying max_steps for a generator-based TextDatasetForNextSentencePrediction

I am fine-tuning a model on my domain using both MLM and NSP. I am using TextDatasetForNextSentencePrediction for NSP and DataCollatorForLanguageModeling for MLM.

The problem is that TextDatasetForNextSentencePrediction loads everything into memory. I modified it so that it generates the examples on the fly instead of storing them all in memory.

But because of that change, my dataset class no longer has a __len__ method, and I get the error:

ValueError: train_dataset does not implement __len__, max_steps has to be specified

I also tried setting max_steps to a very large value, but then it fails inside the torch library with the error:

TypeError: object of type 'TextDatasetForNextSentencePrediction' has no len()

Now my question is: what is the best way to approach this?
What can be done for a huge dataset like mine when using TextDatasetForNextSentencePrediction?

Any pointers and/or directions to this are highly appreciated.

Thanks

Here is the changed code for the class:

class TextDatasetForNextSentencePrediction(torch.utils.data.IterableDataset):
    """
    Stream next-sentence-prediction (NSP) examples from a large text file.

    Unlike a map-style dataset, examples are generated lazily while the file is read,
    so the corpus never has to be held in memory at once. The number of examples is
    only known after a full pass, so this is an ``IterableDataset`` with no ``__len__``.
    As a result:

    * the ``Trainer`` must be given ``max_steps`` explicitly ("train_dataset does not
      implement __len__, max_steps has to be specified");
    * the DataLoader does not build an index sampler, which is what previously
      called ``len()`` on the dataset and failed with "has no len()".

    NOTE(review): with ``num_workers > 1`` every DataLoader worker would iterate the
    whole file and produce its own copy of the examples. Shard the work via
    ``torch.utils.data.get_worker_info()`` before using more than one worker.

    This will be superseded by a framework-agnostic approach soon.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        file_path: str,
        block_size: int,
        overwrite_cache=False,
        short_seq_probability=0.1,
        nsp_probability=0.5,
    ):
        """
        Args:
            tokenizer: tokenizer used to turn lines into token ids and to add special tokens.
            file_path: path of the UTF-8 input file (format described below).
            block_size: maximum example length, special tokens included.
            overwrite_cache: kept for signature compatibility; not used by this streaming version.
            short_seq_probability: per-document probability of targeting a shorter sequence.
            nsp_probability: probability of pairing segment A with a random segment B.
        """
        warnings.warn(
            DEPRECATION_WARNING.format(
                "https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_mlm.py"
            ),
            FutureWarning,
        )
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"

        self.short_seq_probability = short_seq_probability
        self.nsp_probability = nsp_probability
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.block_size = block_size

        # Input file format:
        # (1) One sentence per line. These should ideally be actual sentences, not
        # entire paragraphs or arbitrary spans of text. (Because we use the
        # sentence boundaries for the "next sentence prediction" task).
        # (2) Blank lines between documents. Document boundaries are needed so
        # that the "next sentence prediction" task doesn't span between documents.
        #
        # Example:
        # I am very happy.
        # Here is the second sentence.
        #
        # A new document

    def create_documents_with_batching(self, document_batch=10240):
        """
        Stream examples from ``self.file_path``, buffering ``document_batch`` documents at a time.

        Tokenized documents (each a list of token-id lists, one per line) accumulate in
        ``self.documents``. Once at least ``document_batch`` documents are buffered and a
        document boundary is reached, examples are generated for the whole buffer and the
        buffer is reset. This keeps memory use bounded. Random "next sentences" are
        therefore drawn only from documents in the current buffer.

        Args:
            document_batch: number of buffered documents that triggers a flush.

        Yields:
            Example dicts as produced by ``create_examples_from_document``.
        """
        self.documents = [[]]
        with open(self.file_path, encoding="utf-8") as f:
            # Iterate the file object lazily; ``f.readlines()`` would read the whole
            # corpus into memory, which is exactly what this streaming version avoids.
            for line in f:
                line = line.strip()

                # Empty lines are used as document delimiters
                if not line and len(self.documents[-1]) != 0:
                    if len(self.documents) >= document_batch:
                        for doc_index, document in enumerate(self.documents):
                            yield from self.create_examples_from_document(document, doc_index)
                        self.documents = [[]]
                    else:
                        self.documents.append([])

                tokens = self.tokenizer.tokenize(line)
                tokens = self.tokenizer.convert_tokens_to_ids(tokens)
                if tokens:
                    self.documents[-1].append(tokens)

        # Flush the last (possibly partial) buffer once the whole file has been read.
        for doc_index, document in enumerate(self.documents):
            yield from self.create_examples_from_document(document, doc_index)

    def _has_other_document(self, doc_index: int) -> bool:
        """Return True if ``self.documents`` has a non-empty document at an index other than ``doc_index``."""
        # Short-circuits on the first hit; empty documents can only trail the buffer,
        # so this normally inspects just one or two entries.
        return any(
            other
            for other_index, other in enumerate(self.documents)
            if other_index != doc_index
        )

    def create_examples_from_document(self, document: List[List[int]], doc_index: int):
        """
        Create NSP examples for a single document.

        Args:
            document: the document's segments (one per input line), each a non-empty
                list of token ids.
            doc_index: index of ``document`` in ``self.documents``; random "next
                sentences" are never drawn from this index.

        Yields:
            dicts with ``input_ids``, ``token_type_ids`` and ``next_sentence_label``
            (1 = segment B is random, 0 = B actually follows A), all ``torch.long`` tensors.
        """

        max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)

        # We *usually* want to fill up the entire sequence since we are padding
        # to `block_size` anyways, so short sequences are generally wasted
        # computation. However, we *sometimes*
        # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
        # sequences to minimize the mismatch between pretraining and fine-tuning.
        # The `target_seq_length` is just a rough target however, whereas
        # `block_size` is a hard limit.
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_probability:
            target_seq_length = random.randint(2, max_num_tokens)

        current_chunk = []  # a buffer stored current working segments
        current_length = 0
        i = 0

        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    # `a_end` is how many segments from `current_chunk` go into the `A`
                    # (first) sentence.
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []

                    # A single-segment chunk has no "actual next" segment, so it must
                    # borrow a random one; otherwise a coin flip decides.
                    if len(current_chunk) == 1:
                        is_random_next = True
                    else:
                        is_random_next = random.random() < self.nsp_probability

                    # Without another non-empty document in the buffer (e.g. a final
                    # partial batch holding a single document) the rejection-sampling loop
                    # below could never terminate. Fall back to the actual next segment.
                    if is_random_next and not self._has_other_document(doc_index):
                        is_random_next = False

                    if is_random_next:
                        target_b_length = target_seq_length - len(tokens_a)

                        # This should rarely go for more than one iteration for large
                        # corpora. The loop is guaranteed to terminate because another
                        # non-empty document exists (checked above).
                        while True:
                            random_document_index = random.randint(0, len(self.documents) - 1)
                            random_document = self.documents[random_document_index]
                            if random_document and random_document_index != doc_index:
                                break

                        random_start = random.randint(0, len(random_document) - 1)
                        for j in range(random_start, len(random_document)):
                            tokens_b.extend(random_document[j])
                            if len(tokens_b) >= target_b_length:
                                break
                        # We didn't actually use these segments so we "put them back" so
                        # they don't go to waste.
                        num_unused_segments = len(current_chunk) - a_end
                        i -= num_unused_segments
                    # Actual next
                    else:
                        for j in range(a_end, len(current_chunk)):
                            tokens_b.extend(current_chunk[j])

                    assert len(tokens_a) >= 1

                    # `tokens_b` is only empty for a single-segment chunk that had nothing to
                    # be paired with; such a chunk cannot form an NSP example and is skipped.
                    if tokens_b:
                        # add special tokens
                        input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
                        # add token type ids, 0 for sentence a, 1 for sentence b
                        token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)

                        example = {
                            "input_ids": torch.tensor(input_ids, dtype=torch.long),
                            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                            "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
                        }

                        yield example

                current_chunk = []
                current_length = 0

            i += 1

    def __iter__(self):
        """Yield NSP examples streamed from ``self.file_path``; each call starts a fresh pass."""
        yield from self.create_documents_with_batching()