TrainingArguments class: max_steps formula when using a streaming dataset

Objective

I need a definite formula for deciding what value to set max_steps to when using a streaming dataset.

Background

Several questions have been raised about max_steps when using a streaming dataset.

According to the documentation, max_steps is the total number of training steps, which should be the total number of mini-batches:

If set to a positive number, the total number of training steps to perform. Overrides num_train_epochs.

I am afraid this does not make clear how the value should be computed.

Question

Suppose the train split of a Hugging Face Dataset is a small dataset of 2048 rows, and the training arguments are set as below, with max_steps (MAX_STEPS) still to be determined.

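from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="bloom_finetuned",
    max_steps=MAX_STEPS,             # the value this question is about
    num_train_epochs=3,              # overridden when max_steps > 0
    per_device_train_batch_size=1,   # BLOOM does not fit in GPU memory with more
    per_device_eval_batch_size=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    no_cuda=False,                   # train on the GPU
)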

Then, for a system with a single GPU:

MAX_STEPS = num_train_epochs * num_rows_in_train / per_device_train_batch_size

Where:

  • num_rows_in_train=2048 is the total number of records in the training dataset
  • per_device_train_batch_size=1 is the batch size sent to each GPU
  • num_train_epochs=3 is the number of epochs to run

Is this correct?
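For reference, here is the single-GPU formula above written as a small Python helper. `max_steps_single_gpu` is a hypothetical name, not a transformers function. It assumes no gradient accumulation and rounds each epoch's batch count up so that a final partial batch still counts as a step (an assumption that makes no difference for 2048 rows at batch size 1).

```python
def max_steps_single_gpu(num_train_epochs: int,
                         num_rows_in_train: int,
                         per_device_train_batch_size: int) -> int:
    """Optimizer steps needed for num_train_epochs full passes on one GPU.

    Hypothetical helper: assumes gradient_accumulation_steps=1 and counts a
    final partial batch in each epoch as one step (ceiling division).
    """
    batches_per_epoch = -(-num_rows_in_train // per_device_train_batch_size)
    return num_train_epochs * batches_per_epoch


MAX_STEPS = max_steps_single_gpu(
    num_train_epochs=3,
    num_rows_in_train=2048,
    per_device_train_batch_size=1,
)
print(MAX_STEPS)  # 6144
```

With these settings the result, 6144, matches both "Num examples" and "Total optimization steps" in the training log.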

If multiple GPU devices are used in parallel, then:

MAX_STEPS = num_train_epochs * num_rows_in_train / per_device_train_batch_size / num_gpu_devices

Is this correct?
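A hedged generalization of the multi-GPU formula. It assumes data-parallel training in which each optimizer step consumes per_device_train_batch_size × num_gpu_devices × gradient_accumulation_steps examples, which is what the "Total train batch size (w. parallel, distributed & accumulation)" line in the training log suggests. `max_steps_data_parallel` is hypothetical, and how a streaming dataset is actually split across devices depends on library versions and setup, so treat the per-epoch count as an approximation.

```python
def max_steps_data_parallel(num_train_epochs: int,
                            num_rows_in_train: int,
                            per_device_train_batch_size: int,
                            num_gpu_devices: int = 1,
                            gradient_accumulation_steps: int = 1) -> int:
    """Approximate optimizer steps for num_train_epochs passes over the data.

    Hypothetical helper: assumes every optimizer step consumes one full
    "total train batch" and rounds a final partial batch up to one step.
    """
    total_train_batch_size = (per_device_train_batch_size
                              * num_gpu_devices
                              * gradient_accumulation_steps)
    steps_per_epoch = -(-num_rows_in_train // total_train_batch_size)
    return num_train_epochs * steps_per_epoch


print(max_steps_data_parallel(3, 2048, 1))                     # 6144 (single GPU)
print(max_steps_data_parallel(3, 2048, 1, num_gpu_devices=4))  # 1536
print(max_steps_data_parallel(3, 2048, 1, num_gpu_devices=4,
                              gradient_accumulation_steps=2))  # 768
```

With num_gpu_devices=1 and gradient_accumulation_steps=1 this reduces to the single-GPU formula.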

Confirmation

During training, the Trainer reports a huge number of epochs for the above settings. Is this expected?

***** Running training *****
  Num examples = 6,144
  Num Epochs = 9,223,372,036,854,775,807      <-----
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 6,144
  Number of trainable parameters = 559,214,592

per_device_train_batch_size is set to 1 because training BLOOM uses so much GPU memory that the batch size cannot be set higher than 1.

I can’t answer your questions, but I also saw Num Epochs hit this large number. In the code here, setting max_steps overrides num_train_epochs: the Trainer assigns sys.maxsize as the number of epochs (9,223,372,036,854,775,807, i.e. 2**63 - 1, on 64-bit platforms), so the rest of the code effectively ignores num_train_epochs.
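A minimal sketch of what that branch does when the train dataloader has no length (as with a streaming dataset), based on my reading of recent transformers versions; check the Trainer source for your installed version. `streaming_run_summary` is a hypothetical helper, not a transformers API. It also suggests why "Num examples" is 6,144 rather than 2,048: in this case the number appears to be derived from max_steps rather than counted from the dataset.

```python
import sys


def streaming_run_summary(max_steps: int,
                          per_device_train_batch_size: int,
                          gradient_accumulation_steps: int = 1,
                          world_size: int = 1) -> dict:
    """Reproduce the header numbers printed when max_steps > 0 and the
    train dataloader has no length (hypothetical helper)."""
    total_train_batch_size = (per_device_train_batch_size
                              * gradient_accumulation_steps
                              * world_size)
    return {
        # Epochs stop being a limit; training ends after max_steps updates.
        "num_epochs": sys.maxsize,
        # Not counted from the data: inferred from max_steps.
        "num_examples": total_train_batch_size * max_steps,
        "total_train_batch_size": total_train_batch_size,
        "total_optimization_steps": max_steps,
    }


summary = streaming_run_summary(max_steps=6144, per_device_train_batch_size=1)
print(f"{summary['num_epochs']:,}")    # 9,223,372,036,854,775,807 on 64-bit Python
print(f"{summary['num_examples']:,}")  # 6,144
```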