Multilabel token classification (dataloader issues)

Hi! I am working on a multilabel token classification problem. I have not managed to localize the problem precisely, but I suspect it is somewhere in data collation and/or loading.

The dataset has the following format:

    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000

where (multi-hot encoded) labels are of shape (sequence_len, num_classes).
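
To make the label format concrete, here is a minimal sketch of multi-hot encoding for one sequence. The helper `to_multi_hot` and the toy class ids are hypothetical and only illustrate the shape; they are not taken from the actual preprocessing.

```python
def to_multi_hot(token_classes, num_classes):
    """Turn per-token lists of class ids into a (sequence_len, num_classes) 0/1 matrix."""
    rows = []
    for class_ids in token_classes:
        row = [0] * num_classes
        for class_id in class_ids:
            row[class_id] = 1
        rows.append(row)
    return rows


# Three tokens, four classes; the second token carries two classes at once,
# the third token carries none.
labels = to_multi_hot([[0], [1, 3], []], num_classes=4)
print(labels)  # [[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0]]
```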

Using the default DataCollatorForTokenClassification throws an error in torch_call() because this class expects a 1D list of label ids per example (as in ordinary single-label token classification). This is why I implemented a custom data collator:

import numpy as np
import torch
from transformers import DataCollatorForTokenClassification


class DataCollatorForMultilabelTokenClassification(DataCollatorForTokenClassification):
    def torch_call(self, features):

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        # Pad everything except the labels with the tokenizer.
        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        # Assumed arguments, mirroring the parent DataCollatorForTokenClassification.
        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        # Pad each (seq_len, num_classes) label matrix with rows of label_pad_token_id.
        batch[label_name] = []
        for label in labels:
            padding = np.full((sequence_length - len(label), len(label[0])), self.label_pad_token_id, dtype=int)
            if padding_side == "right":
                padded = np.concatenate((np.array(label), padding)).tolist()
            else:
                padded = np.concatenate((padding, np.array(label))).tolist()
            batch[label_name].append(padded)

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        return batch

data_collator = DataCollatorForMultilabelTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100)
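
The difference between the two padding strategies can be shown without a tokenizer. The sketch below is dependency-free and assumes the default collator pads by appending scalar pad ids to each label list; `pad_scalar` and `pad_rows` are hypothetical helper names used only for illustration.

```python
def pad_scalar(label, sequence_length, pad_id=-100):
    """Scalar padding (assumed behaviour of the default collator): appends plain ints."""
    return list(label) + [pad_id] * (sequence_length - len(label))


def pad_rows(label, sequence_length, pad_id=-100):
    """Row padding on the right, as in the custom collator: appends full pad rows."""
    num_classes = len(label[0])
    missing = sequence_length - len(label)
    return [list(row) for row in label] + [[pad_id] * num_classes for _ in range(missing)]


label = [[1, 0, 0], [0, 1, 1]]  # 2 tokens, 3 classes

ragged = pad_scalar(label, sequence_length=4)
print(ragged)  # [[1, 0, 0], [0, 1, 1], -100, -100]

rectangular = pad_rows(label, sequence_length=4)
print(rectangular)  # [[1, 0, 0], [0, 1, 1], [-100, -100, -100], [-100, -100, -100]]
```

With scalar padding the result mixes rows and plain integers, so it cannot be turned into a rectangular tensor; row padding keeps every entry at length num_classes.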

But during training, the dataloader (the default one, I did not configure it at all) loads a batch with the wrong dimensions:

  • input_ids (and other tokenization fields) are of size (batch_size, sequence_len)
  • labels are expected to be of size (batch_size, sequence_len, num_classes), and this is what they are after the data collator's torch_call(), but they reach the loss function with shape (1, sequence_len, num_classes), so it is only one sample of the batch…
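
One way to narrow down where the batch dimension is lost is to assert the expected shapes at several points (collator output, dataloader output, start of the loss computation). The helper below is a hypothetical sketch, not part of any library; it takes plain shape tuples so it works with torch tensors via `.shape`.

```python
def check_batch_shapes(input_ids_shape, labels_shape, num_classes):
    """Raise ValueError unless labels are (batch_size, sequence_len, num_classes)."""
    expected = (input_ids_shape[0], input_ids_shape[1], num_classes)
    if tuple(labels_shape) != expected:
        raise ValueError(f"labels have shape {tuple(labels_shape)}, expected {expected}")


# Shapes as described above: a consistent batch passes silently.
check_batch_shapes((8, 128), (8, 128, 5), num_classes=5)

# The situation seen in the loss function is reported explicitly.
try:
    check_batch_shapes((8, 128), (1, 128, 5), num_classes=5)
except ValueError as err:
    message = str(err)
    print(message)  # labels have shape (1, 128, 5), expected (8, 128, 5)
```

In practice this could be called as `check_batch_shapes(batch["input_ids"].shape, batch["labels"].shape, num_classes)` at each of those points; the first place it raises is where the labels change shape.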

Am I missing something? I would appreciate any pointers :)