BigBird pretraining

Hi,

I am curious about the new BigBird model, and I’m trying to pretrain one for my language and domain. However, running my usual pretraining script (see below) gives me the following warning, and because block sparse attention cannot be used, my GPU (Tesla V100) runs out of memory almost immediately.

```
Attention type 'block_sparse' is not possible if sequence_length: 130 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3.Changing attention type to 'original_full'...
```
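The 704 in the warning is just the sum of the terms it lists: 2 global blocks, 3 sliding blocks, `num_random_blocks` random blocks and a buffer of another `num_random_blocks` blocks, each `block_size` tokens long. A small sketch of that arithmetic (the function name `block_sparse_threshold` is hypothetical, not part of `transformers`):

```python
def block_sparse_threshold(block_size: int, num_random_blocks: int) -> int:
    """Sequence length at or below which the warning says BigBird falls back
    to 'original_full' attention, summing the terms listed in the message."""
    global_tokens = 2 * block_size
    sliding_tokens = 3 * block_size
    random_tokens = num_random_blocks * block_size
    buffer_tokens = num_random_blocks * block_size
    return global_tokens + sliding_tokens + random_tokens + buffer_tokens


threshold = block_sparse_threshold(block_size=64, num_random_blocks=3)
print(threshold)  # 704
# Block sparse attention is only kept when sequence_length > threshold.
print(130 > threshold)  # False
```

Equivalently, the threshold is `(5 + 2 * num_random_blocks) * block_size`, so with the defaults every batch must be longer than 11 blocks of 64 tokens.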

I understand what the problem is, but not how to solve it. How can I dynamically pad each minibatch so that it is long enough for block sparse attention, i.e. longer than the 704 tokens given by the formula? I suppose I could pad all my samples to the maximum sequence length (4096), but that seems very wasteful. Any pointers on how to proceed would be immensely appreciated.
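One way to express "dynamic" padding here is to pad each batch to its longest sequence, but never below the first multiple of `block_size` above the threshold (768 with the defaults), and cap the result at `max_position_embeddings`. The sketch below works on plain lists of token ids to stay self-contained; the helper names `target_length` and `pad_batch` are hypothetical, the pad id is passed in rather than taken from a real tokenizer, and it assumes sequences are already truncated to 4096 tokens. In a real script this logic would have to run inside a custom collator before MLM masking. Note that `DataCollatorForLanguageModeling` does accept a `pad_to_multiple_of` argument, but that only rounds the length up; it does not enforce the 704-token minimum.

```python
from typing import Dict, List


def target_length(longest: int, block_size: int = 64, num_random_blocks: int = 3,
                  max_length: int = 4096) -> int:
    """Smallest length that is >= longest, strictly above the block sparse
    threshold, and a multiple of block_size (capped at max_length)."""
    threshold = (5 + 2 * num_random_blocks) * block_size  # 704 for the defaults
    needed = max(longest, threshold + 1)
    # Ceiling division, then scale back up to a whole number of blocks.
    rounded = -(-needed // block_size) * block_size
    return min(rounded, max_length)


def pad_batch(batch: List[List[int]], pad_token_id: int, **kwargs) -> Dict[str, List[List[int]]]:
    """Right-pad every sequence to the same block-sparse-friendly length and
    build the matching attention mask (1 = real token, 0 = padding)."""
    length = target_length(max(len(seq) for seq in batch), **kwargs)
    input_ids = [seq + [pad_token_id] * (length - len(seq)) for seq in batch]
    attention_mask = [[1] * len(seq) + [0] * (length - len(seq)) for seq in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Illustrative batch with lengths 130 and 100; pad id 0 is an assumption.
padded = pad_batch([[5] * 130, [7] * 100], pad_token_id=0)
print(len(padded["input_ids"][0]))  # 768
print(target_length(1000))  # 1024
```

This still pads a 130-token batch to 768 tokens, which is more than its natural length but far less than 4096.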

Thanks a lot!

My current pretraining code:

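```python
from datasets import load_from_disk
from transformers import (
    BigBirdConfig,
    BigBirdForMaskedLM,
    BigBirdTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# FLAGS holds the command-line arguments (tokenizer path, data path,
# output dir, batch size) and is defined elsewhere in the script.

# BigBirdTokenizer takes the path to a SentencePiece model file
tokenizer = BigBirdTokenizer(FLAGS.tokenizer)
print('tokenizer:', tokenizer)

# block_size and num_random_blocks are left at their defaults
# (64 and 3, as reported in the warning above)
config = BigBirdConfig(
    vocab_size=tokenizer.vocab_size,
    num_hidden_layers=6,
    max_position_embeddings=4096,
    attention_type="block_sparse",
)

model = BigBirdForMaskedLM(config=config)
print('model', model)
print(model.num_parameters(), 'parameters')

# Tokenized dataset previously saved with save_to_disk
dataset = load_from_disk(FLAGS.data)
print('data loaded')
# Pads each batch to its longest sample and masks 15% of tokens for MLM
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir=FLAGS.output,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=FLAGS.batchsize,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

print('start training!')
trainer.train()
print('done training!')
trainer.save_model(training_args.output_dir)
```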
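An alternative to padding, used by the `group_texts` function in the `run_mlm.py` language-modeling example of the `transformers` repository, is to concatenate all tokenized samples and cut them into equal-length chunks, so every training example is already long enough for block sparse attention and almost no padding is needed. A minimal pure-Python sketch of that idea, assuming every column in the dataset (e.g. `input_ids`, `attention_mask`) holds lists of token ids; this is an assumption about how the data behind `FLAGS.data` was tokenized:

```python
from itertools import chain
from typing import Dict, List


def group_texts(examples: Dict[str, List[List[int]]], chunk_size: int = 4096) -> Dict[str, List[List[int]]]:
    """Concatenate each column and split it into chunks of exactly chunk_size
    tokens. The leftover tail that does not fill a whole chunk is dropped."""
    concatenated = {key: list(chain.from_iterable(values)) for key, values in examples.items()}
    total_length = len(next(iter(concatenated.values())))
    # Keep only as many tokens as fit into whole chunks.
    total_length = (total_length // chunk_size) * chunk_size
    return {
        key: [tokens[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for key, tokens in concatenated.items()
    }


example = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
print(group_texts(example, chunk_size=4))  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```

With `datasets`, this would be applied as `dataset.map(group_texts, batched=True)`, removing any non-token columns (such as raw text) via `remove_columns`. Any chunk size that is a multiple of `block_size` and larger than 704 should also keep block sparse attention enabled, if 4096-token chunks turn out to be too large for the GPU.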