Hi,
I have developed a custom data collator for back-translation.
During training, the collator uses the model itself to generate noisy translations of the inputs and then builds a new batch in which the original inputs become the labels and the noisy translations become the inputs. Here is the relevant part of the collator's `__call__`:
```python
import torch
from transformers import GenerationConfig

def __call__(self, features):
    # Stack the pre-padded examples into batch tensors.
    input_ids = torch.stack([torch.IntTensor(example['input_ids']) for example in features])
    attention_mask = torch.stack([torch.IntTensor(example['attention_mask']) for example in features])

    generation_config = GenerationConfig(
        max_new_tokens=self.max_length,
        decoder_start_token_id=decoder_start_token_id,  # defined earlier in the collator (omitted here)
        pad_token_id=self.tokenizer.pad_token_id,
        bos_token_id=self.tokenizer.bos_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
    )

    # Generate the noisy translations with the model being trained.
    self.accelerator.wait_for_everyone()
    self.model.eval()
    with torch.no_grad():
        generated_tokens = self.accelerator.unwrap_model(self.model).generate(
            input_ids=input_ids.to(self.accelerator.device),
            attention_mask=attention_mask.to(self.accelerator.device),
            generation_config=generation_config,
        ).to('cpu')

    # Decode the generated tokens (dropping the decoder start token) and
    # re-tokenize them as the new input; the original inputs become the labels.
    noisy_input = [self.tokenizer.decode(ids[1:]) for ids in generated_tokens]
    batch = self.tokenizer(noisy_input, truncation=True, max_length=self.max_length,
                           padding=True, return_tensors="pt")
    batch['labels'] = input_ids.clone().detach()
    batch['labels'][batch['labels'] == self.tokenizer.pad_token_id] = -100

    self.accelerator.wait_for_everyone()
    self.model.train()
    return batch
```
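For context, here is roughly how the collator is plugged into training; the class name, checkpoint, dummy data, and hyperparameter values below are just placeholders rather than my exact setup:

```python
# Rough sketch of how the collator is wired up (placeholder checkpoint,
# data, and hyperparameters -- not my exact training script).
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("t5-small")          # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model, optimizer = accelerator.prepare(model, optimizer)

# Instance of the collator class whose __call__ is shown above (the class
# name is a placeholder); it receives the prepared model, hence the
# unwrap_model(...) call inside __call__.
collator = BackTranslationCollator(tokenizer=tokenizer, model=model,
                                   accelerator=accelerator, max_length=128)

# Dummy pre-padded dataset with the 'input_ids' / 'attention_mask' keys the collator expects.
texts = ["a first source sentence", "a second source sentence"]
train_dataset = [tokenizer(t, truncation=True, padding="max_length", max_length=128)
                 for t in texts]

train_dataloader = accelerator.prepare(
    DataLoader(train_dataset, batch_size=2, collate_fn=collator))

for batch in train_dataloader:
    outputs = model(**batch)             # noisy translations in, original inputs as labels
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
```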
The approach works well, but it is slow! The main cost is the `accelerator.unwrap_model(self.model).generate(...)` call, which runs inside the collator for every training batch.
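To make the bottleneck concrete, it boils down to the pattern below, shown here as a standalone toy example (t5-small and the dummy batch are placeholders, not my actual model or data):

```python
# Standalone toy version of the per-batch generation step that dominates
# each training iteration (placeholder checkpoint and dummy inputs).
import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = accelerator.prepare(AutoModelForSeq2SeqLM.from_pretrained("t5-small"))

batch = tokenizer(["a toy source sentence"] * 16,
                  padding=True, return_tensors="pt").to(accelerator.device)

# This block runs once for every training batch inside my collator.
model.eval()
with torch.no_grad():
    generated = accelerator.unwrap_model(model).generate(
        **batch, generation_config=GenerationConfig(max_new_tokens=128))
model.train()
```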
Does anyone have any suggestions on how I can speed up this collator?