Hi! I am working on a multilabel token classification problem. I have not managed to localize the issue precisely, but I suspect it is somewhere in data collation and/or loading.
The dataset has the following format:
Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})
where (multi-hot encoded) labels are of shape (sequence_len, num_classes).
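For reference, a single example looks roughly like this (a minimal sketch with made-up values, assuming num_classes = 3):

# One dataset example (illustrative values only, assuming num_classes = 3)
example = {
    "input_ids":      [101, 7592, 2088, 102],
    "token_type_ids": [0, 0, 0, 0],
    "attention_mask": [1, 1, 1, 1],
    # one multi-hot vector per token: shape (sequence_len, num_classes)
    "labels": [
        [0, 0, 0],  # [CLS]
        [1, 0, 1],  # token assigned to classes 0 and 2
        [0, 1, 0],  # token assigned to class 1
        [0, 0, 0],  # [SEP]
    ],
}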
Using the default DataCollatorForTokenClassification throws an error in torch_call(), because that class expects a 1D list of label_ids per example (as in standard single-label token classification). This is why I implemented a custom data collator:
import numpy as np
import torch
from transformers import DataCollatorForTokenClassification

class DataCollatorForMultilabelTokenClassification(DataCollatorForTokenClassification):
    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        # Pad the tokenization fields as usual
        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        # Pad each (seq_len, num_classes) label matrix along the sequence dimension
        # with rows of label_pad_token_id, so the labels stack to
        # (batch_size, sequence_length, num_classes)
        if padding_side == "right":
            batch[label_name] = []
            for label in labels:
                padding = np.full((sequence_length - len(label), len(label[0])), self.label_pad_token_id, dtype=int)
                padded = np.concatenate((np.array(label), padding)).tolist()
                batch[label_name].append(padded)
        else:
            batch[label_name] = []
            for label in labels:
                padding = np.full((sequence_length - len(label), len(label[0])), self.label_pad_token_id, dtype=int)
                padded = np.concatenate((padding, np.array(label))).tolist()
                batch[label_name].append(padded)

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        return batch
data_collator = DataCollatorForMultilabelTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100)
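Calling the collator directly on a few examples does produce the expected shapes (a quick sanity check, assuming the dataset and tokenizer from above):

# Sanity check on a handful of raw examples from the dataset
features = [dataset[i] for i in range(4)]
batch = data_collator(features)

print(batch["input_ids"].shape)  # torch.Size([4, sequence_len])
print(batch["labels"].shape)     # torch.Size([4, sequence_len, num_classes])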
But during training the dataloader (which is the default one, I did not configure it at all) loads a batch with wrong dimensions: input_ids (and the other tokenization fields) come out as (batch_size, sequence_len), and labels are expected to be (batch_size, sequence_len, num_classes), which is indeed their shape right after the collator's torch_call(); but they reach the loss function with shape (1, sequence_len, num_classes), i.e. only one sample of the batch…
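To illustrate where the wrong shape shows up, here is a rough sketch of the loss side (not my exact code; the MultilabelTrainer name and BCEWithLogitsLoss are just an outline of a typical multilabel setup, and masking of the -100 padding rows is omitted):

import torch.nn as nn
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Shapes observed at this point:
        #   inputs["input_ids"]: (batch_size, sequence_len)      -- as expected
        #   inputs["labels"]:    (1, sequence_len, num_classes)  -- only one sample of the batch
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits  # (batch_size, sequence_len, num_classes)
        loss = nn.BCEWithLogitsLoss()(logits, labels.float())
        return (loss, outputs) if return_outputs else loss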
Am I missing something? I would appreciate any pointers :)