I found out that my data collator receives only 'attention_mask' as input. I do not know where the other fields disappear to.
from typing import Dict, List

import torch


class T2TDataCollator():
    def __init__(self, tokenizer, mode='training'):
        self.tokenizer = tokenizer
        self.mode = mode

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Returns:
            A dictionary of tensors
        """
        input_ids = torch.stack([example['source_ids'] for example in batch])
        target_ids = torch.stack([example['target_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        .....
The error is thrown in this part.
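For reference, when run in isolation the collator does keep all fields, as long as every example in the batch actually contains them. Below is a minimal self-contained sketch with dummy tensors in place of real tokenizer output; the returned dict (including the 'labels' key) is my assumption about what the elided part of the snippet builds, not the original code:

```python
from typing import Dict, List

import torch


class T2TDataCollator:
    def __init__(self, tokenizer=None, mode='training'):
        self.tokenizer = tokenizer
        self.mode = mode

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        # Stack the per-example 1-D tensors into (batch_size, seq_len) tensors.
        input_ids = torch.stack([ex['source_ids'] for ex in batch])
        target_ids = torch.stack([ex['target_ids'] for ex in batch])
        attention_mask = torch.stack([ex['attention_mask'] for ex in batch])
        # Hypothetical return dict, standing in for the elided part above.
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': target_ids,
        }


# Two dummy examples, each with all three fields present.
batch = [
    {'source_ids': torch.tensor([1, 2]),
     'target_ids': torch.tensor([3, 4]),
     'attention_mask': torch.tensor([1, 1])}
    for _ in range(2)
]
out = T2TDataCollator()(batch)
print(sorted(out.keys()))  # → ['attention_mask', 'input_ids', 'labels']
```

Since the collator itself preserves the keys here, the fields are most likely being dropped before the collator is called, somewhere in the dataset or training pipeline rather than in this class.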