How does `max_steps` affect the number of samples the model "sees"?


I am fine-tuning starcoderbase-3b on a small local dataset of around 7,000 programs. I've looked online, but I am still confused about how the `max_steps` parameter affects how many times the model sees each sample (program).

If I set `max_steps = 21000`, will the model see each program 3 times, i.e. something like 3 epochs of training? If I set `max_steps < 7000`, will the model not see all the programs? I'm asking because no matter the `max_steps` value, at the end of training I always seem to get 1 epoch.

Currently, I have `batch_size = 1` and `gradient_accumulation_steps = 4`. Do those numbers also play a role in this?

Thank you

Hi @gsakkas,

The `max_steps` argument is the number of optimization steps (optimizer updates) that run during training. With gradient accumulation, each optimization step is made up of `gradient_accumulation_steps` forward + backward passes, each over one batch, so both batch size and gradient accumulation determine how many examples a single step consumes.

If you have 7000 examples, batch size = 1, and gradient accumulation = 4, then for training to see every example just once you need 7000 / (1 * 4) = 1750 steps. With a batch size of 8 and gradient accumulation of 4, you'd need 7000 / 32 = 218.75, rounded up to 219 steps, for the model to see every example once. (This assumes a single GPU; with several GPUs, the per-device batch size is multiplied by the number of devices.)
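The arithmetic above can be written as a small helper. This is a minimal sketch: `steps_per_epoch` and its `num_devices` parameter are hypothetical names, not part of any library, and it ignores how the trainer handles a final partial accumulation window, so the trainer's own count may differ by one step.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int, grad_accum: int, num_devices: int = 1) -> int:
    """Hypothetical helper: optimizer steps needed to see every example once."""
    # Examples consumed by one optimizer update across all devices.
    examples_per_step = batch_size * grad_accum * num_devices
    # Round up so a last, partially filled step is still counted.
    return math.ceil(num_examples / examples_per_step)


print(steps_per_epoch(7000, batch_size=1, grad_accum=4))  # 1750
print(steps_per_epoch(7000, batch_size=8, grad_accum=4))  # 219
```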

So with your config, 21000 steps means the model sees every example 12 times (21000 / 1750 = 12).
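Going the other way, from a `max_steps` value to the number of passes over the data, is a single multiplication and division. This is a sketch with a hypothetical helper name; it also shows that with gradient accumulation of 4, `max_steps = 7000` already means four passes, and only values below 1750 leave part of the data unseen.

```python
def passes_over_data(max_steps: int, num_examples: int, batch_size: int,
                     grad_accum: int, num_devices: int = 1) -> float:
    """Hypothetical helper: average number of times each example is seen."""
    # Total examples consumed over the whole run.
    examples_seen = max_steps * batch_size * grad_accum * num_devices
    return examples_seen / num_examples


print(passes_over_data(21000, 7000, batch_size=1, grad_accum=4))  # 12.0
print(passes_over_data(7000, 7000, batch_size=1, grad_accum=4))   # 4.0
print(passes_over_data(1000, 7000, batch_size=1, grad_accum=4))   # ~0.57, not every program is seen
```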

It's often less confusing to set the number of epochs (`num_train_epochs`) instead of the number of steps. 1 epoch means that, no matter what the batch size or gradient accumulation is, every example will be seen exactly once; 2 epochs, twice; and so on.
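One thing to keep in mind when switching: in Hugging Face `TrainingArguments`, `max_steps` defaults to -1, and a positive value overrides `num_train_epochs`. The sketch below mirrors that documented precedence rule with a hypothetical helper; it is not the Trainer's actual code.

```python
import math


def total_optimizer_steps(num_train_epochs: float, max_steps: int, steps_per_epoch: int) -> int:
    """Hypothetical helper mirroring the max_steps / num_train_epochs precedence."""
    if max_steps > 0:
        # A positive max_steps wins; num_train_epochs is ignored.
        return max_steps
    # Otherwise the run length follows from the number of epochs.
    return math.ceil(num_train_epochs * steps_per_epoch)


print(total_optimizer_steps(3, max_steps=-1, steps_per_epoch=1750))     # 5250
print(total_optimizer_steps(3, max_steps=21000, steps_per_epoch=1750))  # 21000
```

So to train by epochs, leave `max_steps` at its default and set only `num_train_epochs`.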


Thank you for your elaborate response!

Just to clarify: when I set `max_steps`, even to a higher value that, as you said, is enough for 12 epochs, the epoch reported at the end of training always seems to be 1. Is the true number of epochs not reported when steps are used? I think my confusion stemmed from that.

I'm not sure what's going on in that case. My reading of the trainer code is that `num_epochs` should be set to 12 (looking at this part). When this logger statement is reached during training, does it just report 1?
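For a dataset with a known length, the epoch count that follows from `max_steps` can be sketched as below. This is a simplified paraphrase under that assumption, with hypothetical names, not the Trainer's actual code. If the training dataset has no known length (for example, an iterable or streaming dataset), this calculation isn't possible, so the reported epoch may not reflect passes over the data.

```python
def epochs_from_max_steps(max_steps: int, num_batches: int, grad_accum: int) -> int:
    """Hypothetical paraphrase: epochs implied by max_steps for a sized dataset."""
    # Optimizer updates in one pass over the dataloader (at least 1).
    updates_per_epoch = max(num_batches // grad_accum, 1)
    # Whole epochs needed, plus one more if max_steps ends mid-epoch.
    return max_steps // updates_per_epoch + int(max_steps % updates_per_epoch > 0)


# 7000 programs with batch_size=1 -> 7000 batches per epoch.
print(epochs_from_max_steps(21000, num_batches=7000, grad_accum=4))  # 12
print(epochs_from_max_steps(1000, num_batches=7000, grad_accum=4))   # 1
```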
