Data Preparation for CausalLM

Hi all,

I want to train a CausalLM (gpt2) following this course.
For this, I am using DataCollatorForLanguageModeling with the flag mlm set to False.
However, I am still unsure how exactly the batches are generated from one sample.
Given a tokenized sample

[10, 14, 36, 28, 30, 31, 77, 100, 101]

the data collator returns the following input and label for training

input = [10, 14, 36, 28, 30, 31, 77, 100, 101]
label = [10, 14, 36, 28, 30, 31, 77, 100, 101]

In the documentation of the data collator I already found that the labels are shifted automatically inside the model during training. Still, for causal language modeling I would want to create multiple inputs and labels from the given sample, so that the model has to predict the correct token at each position, hence:

input = [
	[10,  0,  0,  0,  0,  0,  0,   0,   0]
	[10, 14,  0,  0,  0,  0,  0,   0,   0]
	[10, 14, 36,  0,  0,  0,  0,   0,   0]
	[10, 14, 36, 28,  0,  0,  0,   0,   0]
	[10, 14, 36, 28, 30,  0,  0,   0,   0]
	[10, 14, 36, 28, 30, 31,  0,   0,   0]
	[10, 14, 36, 28, 30, 31, 77,   0,   0]
	[10, 14, 36, 28, 30, 31, 77, 100,   0]
]
label = [
	[10, 14,  0,  0,  0,  0,  0,   0,   0]
	[10, 14, 36,  0,  0,  0,  0,   0,   0]
	[10, 14, 36, 28,  0,  0,  0,   0,   0]
	[10, 14, 36, 28, 30,  0,  0,   0,   0]
	[10, 14, 36, 28, 30, 31,  0,   0,   0]
	[10, 14, 36, 28, 30, 31, 77,   0,   0]
	[10, 14, 36, 28, 30, 31, 77, 100,   0]
	[10, 14, 36, 28, 30, 31, 77, 100, 101]
]
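
For reference, this is how I understand the shift mentioned in the documentation, sketched in plain Python. It is only an illustration under the assumption that the output at position i is scored against the token at position i + 1; `shifted_pairs` is a hypothetical helper of mine, not part of transformers.

```python
def shifted_pairs(token_ids):
    """Return the (context, target) pairs implied by shifting the labels by one.

    Hypothetical helper for illustration only, not a transformers API.
    """
    targets = token_ids[1:]  # the label at position i is the token at position i + 1
    # The last token has no successor, so there are len(token_ids) - 1 pairs.
    return [(token_ids[: i + 1], targets[i]) for i in range(len(targets))]


sample = [10, 14, 36, 28, 30, 31, 77, 100, 101]
pairs = shifted_pairs(sample)
for context, target in pairs:
    print(context, "->", target)
```

If that assumption holds, the single sample already contains the same eight prediction targets as the expanded lists above.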

My question is now:
Is this done automatically by the "CausalLM" model, or do I have to implement it myself in a custom dataloader/dataset?

I just figured it out myself.
As explained in this video, ensuring that the model predicts the correct next token at each position is usually done with a triangular mask in the self-attention layer, not by passing every prefix as a separate sample.
Looking at the GPT2 implementation from Hugging Face, I found that the GPT2Attention module applies a triangular causal_mask for this, so there should be no need to preprocess the data manually as asked above :)
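
To make the triangular mask concrete, here is a minimal sketch in plain Python. It is not the GPT2Attention code: `causal_attention_weights` is a hypothetical function and the scores are made-up numbers. It only shows the general idea that position i may attend to positions 0..i and never to later ones.

```python
import math


def causal_attention_weights(scores):
    """Row-wise softmax over scores, with row i restricted to columns 0..i.

    Illustrative only; a hypothetical helper, not the library implementation.
    """
    n = len(scores)
    weights = []
    for i in range(n):
        # Lower-triangular mask: keep columns 0..i, drop the "future" columns.
        visible = scores[i][: i + 1]
        peak = max(visible)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        # Masked (future) positions receive a weight of exactly zero.
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights


# Made-up raw attention scores for a 4-token sequence.
scores = [
    [0.5, 1.0, -0.3, 2.0],
    [0.1, 0.2, 0.3, 0.4],
    [1.5, -1.0, 0.0, 0.7],
    [0.0, 0.0, 0.0, 0.0],
]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

Every row of the result only has non-zero weights up to its own position, so the prediction at each position is computed from its prefix alone. That is why one pass over the full sequence covers all the prefixes I listed in the question.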