Speeding up custom data collator

Hi,

I have developed a custom collator for back-translation.

While training the model, this collator uses the model itself to generate noisy translated data from the inputs, and then creates a new batch in which the original inputs become the labels and the noisy translated data becomes the new input. The relevant part of the collator looks like this:

import torch
from transformers import GenerationConfig

# Stack the pre-tokenized examples into batch tensors. LongTensor (int64)
# is the dtype expected for token ids and for cross-entropy labels.
input_ids = torch.stack([torch.LongTensor(example['input_ids']) for example in features])
attention_mask = torch.stack([torch.LongTensor(example['attention_mask']) for example in features])

# decoder_start_token_id is set earlier in the collator, from the model config.
generation_config = GenerationConfig(max_new_tokens=self.max_length,
                                     decoder_start_token_id=decoder_start_token_id,
                                     pad_token_id=self.tokenizer.pad_token_id,
                                     bos_token_id=self.tokenizer.bos_token_id,
                                     eos_token_id=self.tokenizer.eos_token_id)

# Generate the noisy translations with the current model weights.
self.accelerator.wait_for_everyone()
self.model.eval()
with torch.no_grad():
    generated_tokens = self.accelerator.unwrap_model(self.model).generate(
        input_ids=input_ids.to(self.accelerator.device),
        attention_mask=attention_mask.to(self.accelerator.device),
        generation_config=generation_config,
    ).to('cpu')

# Decode the generations (dropping the leading decoder start token) and
# re-tokenize them so that they become the new inputs ...
noisy_input = [self.tokenizer.decode(ids[1:]) for ids in generated_tokens]
batch = self.tokenizer(noisy_input, truncation=True, max_length=self.max_length,
                       padding=True, return_tensors="pt")

# ... while the original inputs become the labels, with padding masked out.
batch['labels'] = input_ids.clone().detach()
batch['labels'][batch['labels'] == self.tokenizer.pad_token_id] = -100

self.accelerator.wait_for_everyone()
self.model.train()
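For context, the snippet above is the body of the collator's __call__ method; the collator instance is then handed to a plain DataLoader that accelerate prepares. A minimal sketch of that wiring (train_dataset, batch_size, collator, and accelerator are placeholders for my actual objects, not part of the question):

from torch.utils.data import DataLoader

# collator is an instance of the class whose __call__ body is shown above.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collator,
)
train_dataloader = accelerator.prepare(train_dataloader)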

The approach works well, but it is slow!

This is mainly due to the accelerator.unwrap_model(model).generate(...) call.
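This is easy to confirm with a rough timing check around the call (reusing the names from the snippet above; the .to('cpu') copy synchronizes the GPU, so the measured time covers the whole generation):

import time

start = time.perf_counter()
with torch.no_grad():
    generated_tokens = self.accelerator.unwrap_model(self.model).generate(
        input_ids=input_ids.to(self.accelerator.device),
        attention_mask=attention_mask.to(self.accelerator.device),
        generation_config=generation_config,
    ).to('cpu')
print(f"generate() took {time.perf_counter() - start:.2f}s for this batch")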

Does anyone have any suggestions on how I can speed up this collator?
