Multilabel+multiclass classification

Hi,
I am interested in solving a multilabel + multiclass classification problem, i.e., I have 9 labels, and each of these 9 labels can take more than two classes.
I tried building 9 classification heads and computing the loss as the average of the losses of all 9 heads.
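A minimal PyTorch sketch of that setup, assuming each example has exactly one class index per head and that the 9 heads sit on top of a pooled encoder representation (for example a [CLS] vector; the encoder itself is omitted here). The class name `MultiHeadClassifier` and the per-head class counts in `NUM_CLASSES` are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadClassifier(nn.Module):
    """Hypothetical sketch: one linear head per label, loss = mean of the per-head losses."""

    def __init__(self, hidden_size, num_classes_per_head):
        super().__init__()
        # One head per label; each head can have its own number of classes.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, n) for n in num_classes_per_head]
        )

    def forward(self, pooled, labels=None):
        # pooled: (batch, hidden_size) representation from an encoder (assumed).
        # labels: (batch, num_heads) LongTensor, column i holds the class index for head i.
        logits = [head(pooled) for head in self.heads]
        loss = None
        if labels is not None:
            per_head = [F.cross_entropy(lg, labels[:, i]) for i, lg in enumerate(logits)]
            loss = torch.stack(per_head).mean()  # average over the 9 heads
        return loss, logits


# Usage with random data: 9 heads, hypothetical class counts per head.
NUM_CLASSES = [3, 4, 3, 5, 3, 3, 6, 3, 4]
torch.manual_seed(0)
model = MultiHeadClassifier(hidden_size=16, num_classes_per_head=NUM_CLASSES)
pooled = torch.randn(8, 16)
labels = torch.stack([torch.randint(0, n, (8,)) for n in NUM_CLASSES], dim=1)  # (8, 9)
loss, logits = model(pooled, labels)
loss.backward()
```

Keeping the heads in an `nn.ModuleList` registers their parameters with the model, so an optimizer over `model.parameters()` trains all of them.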
My labels are stored in a list of 9 elements, one label per head. However, DataCollatorForTokenClassification does not seem to support lists, and I can't figure out how to make my data compatible with a data collator.
Is it possible to create a custom DataCollator for this purpose, or did I miss something that solves my problem?
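A data collator only needs to be a callable that turns a list of feature dicts into a dict of tensors, so a custom one can be passed, for example, as the `data_collator` argument of `Trainer`. A minimal sketch, assuming each example has `input_ids` plus a `labels` list of 9 class indices (one per head, i.e. sequence-level rather than per-token labels); `MultiHeadDataCollator` and its right-padding scheme are hypothetical:

```python
import torch


class MultiHeadDataCollator:
    """Hypothetical collator: pads input_ids, builds attention_mask, stacks the label lists."""

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)
        input_ids, attention_mask = [], []
        for f in features:
            n_pad = max_len - len(f["input_ids"])
            # Right-pad with the pad token; mask marks real tokens with 1.
            input_ids.append(list(f["input_ids"]) + [self.pad_token_id] * n_pad)
            attention_mask.append([1] * len(f["input_ids"]) + [0] * n_pad)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            # Each example's 9-element label list becomes one row: shape (batch, 9).
            "labels": torch.tensor([f["labels"] for f in features], dtype=torch.long),
        }


features = [
    {"input_ids": [101, 7, 102], "labels": [0, 1, 2, 0, 1, 2, 0, 1, 2]},
    {"input_ids": [101, 102], "labels": [2, 2, 1, 0, 0, 1, 2, 0, 1]},
]
batch = MultiHeadDataCollator(pad_token_id=0)(features)
```

If the labels were per-token instead, each of the 9 label sequences would also need padding to the batch length, typically with -100 so that cross-entropy ignores those positions.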
Thank you very much in advance.