Data sampler based on number of tokens

I’m training the large BART model and I have to set my batch size very low to avoid OOM errors (on a 24 GB RTX 3090). The issue is that my sequence lengths are highly variable: the long ones, close to the 1024-token limit, fit no more than 2 sequences per batch, but many of the shorter ones could be batched 8 or more at a time. Does transformers have a way to batch based on the number of tokens instead of a fixed batch size? (BTW… I’m already using DataCollatorForSeq2Seq and group_by_length=True for dynamic padding of batches.) A rough sketch of what I have in mind is below.
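For illustration, here’s a minimal sketch of the kind of sampler I’m imagining. The class name `MaxTokensBatchSampler` and the `max_tokens` budget are my own invention, not an existing transformers API; it just groups indices so that the padded token count of each batch stays under a budget:

```python
import torch
from torch.utils.data import Sampler


class MaxTokensBatchSampler(Sampler):
    """Yields batches of dataset indices whose padded token count stays
    under a budget, instead of using a fixed number of sequences per batch."""

    def __init__(self, lengths, max_tokens=8192, shuffle=True):
        self.shuffle = shuffle
        # Sort by length so similarly sized sequences share a batch,
        # which keeps padding waste low (same idea as group_by_length).
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batches, batch, max_len = [], [], 0
        for idx in order:
            new_max = max(max_len, lengths[idx])
            # Padded cost of a batch = longest sequence * number of sequences.
            # Close the current batch if adding this example would exceed
            # the budget. (A single over-budget sequence still gets its
            # own batch, since the check requires a non-empty batch.)
            if batch and new_max * (len(batch) + 1) > max_tokens:
                self.batches.append(batch)
                batch, new_max = [], lengths[idx]
            batch.append(idx)
            max_len = new_max
        if batch:
            self.batches.append(batch)

    def __iter__(self):
        # Shuffle batch order each epoch; batch composition stays fixed.
        if self.shuffle:
            order = torch.randperm(len(self.batches)).tolist()
        else:
            order = range(len(self.batches))
        return (self.batches[i] for i in order)

    def __len__(self):
        return len(self.batches)
```

With a plain PyTorch DataLoader this would presumably be wired up via `batch_sampler` (here `lengths` is a precomputed list of per-example token counts, and `data_collator` is the DataCollatorForSeq2Seq instance):

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,
    batch_sampler=MaxTokensBatchSampler(lengths, max_tokens=4096),
    collate_fn=data_collator,
)
```

With Trainer, I assume I’d have to subclass it and override `get_train_dataloader` to plug this in, but I’d rather use something built in if it exists.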
