Need a definite formula for the value of max_steps when using a streaming dataset
Several questions have been raised about max_steps when using a streaming dataset:
- Explicitly set number of training steps using Trainer
- Streaming dataset into Trainer: does not implement __len__, max_steps has to be specified
According to the documentation, max_steps is the total number of training steps, which should be the total number of mini-batches:
If set to a positive number, the total number of training steps to perform. Overrides num_train_epochs.
I am afraid it is not clear.
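For context, the reason max_steps has to be given at all is that a streaming dataset is iterable but has no length, so the number of steps per epoch cannot be derived from it. Below is a minimal sketch of that situation in plain Python. The class `StreamingRows` is a hypothetical stand-in for illustration only, not the datasets library's actual streaming class.

```python
class StreamingRows:
    """Hypothetical stand-in for a streaming dataset: iterable, but no __len__."""

    def __init__(self, num_rows):
        self._num_rows = num_rows

    def __iter__(self):
        # Rows are produced lazily, one at a time.
        for i in range(self._num_rows):
            yield {"id": i}


stream = StreamingRows(2048)

try:
    len(stream)
    has_length = True
except TypeError:
    # No __len__ is defined, so the size is unknown until the stream is consumed.
    has_length = False

print(has_length)
```

Without a length, the only way to bound training is an explicit step count, which is what max_steps provides.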
Suppose there is a small dataset of 2048 rows in the train split of a Hugging Face Dataset, and the training arguments are set as below, with max_steps still to be decided.
    from transformers import TrainingArguments

    # MAX_STEPS is the value in question; see the formulas below.
    training_args = TrainingArguments(
        output_dir="bloom_finetuned",
        max_steps=MAX_STEPS,
        num_train_epochs=3,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=2e-5,
        weight_decay=0.01,
        no_cuda=False,
    )
Then, for a system with a single GPU:
MAX_STEPS = num_train_epochs * num_rows_in_train / per_device_train_batch_size
- num_rows_in_train=2048 is the total number of records in the training dataset
- per_device_train_batch_size=1 is the batch size to be sent to GPU
- num_train_epochs=3 is the number of epochs to run
Is this correct?
If multiple GPU devices are used in parallel, then:
MAX_STEPS = num_train_epochs * num_rows_in_train / per_device_train_batch_size / num_gpu_devices
Is this correct?
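To make the two candidate formulas concrete, here is a minimal sketch that computes the step count for both the single-GPU and multi-GPU cases. The helper name `estimate_max_steps` and the `grad_accum_steps` parameter are my own additions for illustration. The sketch assumes that each device consumes a distinct part of the data and that a final partial batch still counts as one step. It is an estimate, not a statement of what Trainer computes internally.

```python
import math


def estimate_max_steps(num_rows, num_epochs, per_device_batch_size,
                       num_devices=1, grad_accum_steps=1):
    """Estimate the optimizer steps needed for num_epochs passes over num_rows."""
    # Rows consumed by one optimizer step across all devices.
    effective_batch = per_device_batch_size * num_devices * grad_accum_steps
    # Round up so that a final partial batch still counts as a step.
    steps_per_epoch = math.ceil(num_rows / effective_batch)
    return num_epochs * steps_per_epoch


# Single GPU, values from the question: 3 epochs over 2048 rows, batch size 1.
print(estimate_max_steps(2048, 3, 1))                 # 6144

# Same settings spread over 4 GPUs (the device count is a hypothetical value).
print(estimate_max_steps(2048, 3, 1, num_devices=4))  # 1536
```

With the single-GPU settings this gives 6144, which matches the "Total optimization steps" value in the training log.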
The Trainer log shows a huge number of epochs for the above settings. Is this expected?
    ***** Running training *****
      Num examples = 6,144
      Num Epochs = 9,223,372,036,854,775,807    <-----
      Instantaneous batch size per device = 1
      Total train batch size (w. parallel, distributed & accumulation) = 1
      Gradient Accumulation steps = 1
      Total optimization steps = 6,144
      Number of trainable parameters = 559,214,592
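One observation about the epoch count: 9,223,372,036,854,775,807 is exactly 2**63 - 1, which is the value of `sys.maxsize` on 64-bit Python builds. That suggests the number is a placeholder meaning "no epoch limit, stop at max_steps" rather than a real epoch count. That reading is an assumption on my part; the short check below only confirms the arithmetic.

```python
import sys

# Epoch count copied from the training log above.
reported_epochs = 9_223_372_036_854_775_807

# The value is the largest signed 64-bit integer.
is_int64_max = reported_epochs == 2**63 - 1

# On 64-bit CPython builds this is also sys.maxsize.
matches_sys_maxsize = reported_epochs == sys.maxsize

print(is_int64_max, matches_sys_maxsize)
```

If that reading is right, training still ends after the 6,144 optimization steps shown in the log, regardless of the epoch figure.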
per_device_train_batch_size=1 is used because training BLOOM takes up so much GPU memory that a value greater than 1 cannot be set.