OK, I think this is the code:
import torch
from transformers import PreTrainedTokenizer


def custom_collate_fn_train_on_first_eos_occurrence(
    data: list[dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    context_length: int,
) -> dict[str, torch.Tensor]:
    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Extract the text sequences
    sequences: list[str] = [example.get("text", "") or "" for example in data]

    # Tokenize the sequences to a fixed length
    tokenized_data = tokenizer(
        sequences,
        padding="max_length",
        max_length=context_length,
        truncation=True,
        return_tensors="pt",
    )

    # Clone input_ids to labels
    tokenized_data["labels"] = tokenized_data["input_ids"].clone()

    eos_token_id = tokenizer.eos_token_id
    for idx, input_ids in enumerate(tokenized_data["input_ids"]):
        # Find all occurrences of eos_token in this sequence
        eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.nelement() > 0:  # Check if eos_token is present
            # Unmask the first eos_token so the model is trained to emit it
            first_eos_position = eos_positions[0]
            tokenized_data["attention_mask"][idx, first_eos_position] = 1
            # Labels were cloned from input_ids, so the label at the first
            # eos_token position is already eos_token_id
            assert tokenized_data["labels"][idx, first_eos_position] == eos_token_id, "The label for the first eos_token is incorrect!"
            # Ignore every subsequent eos_token (padding) by setting its label to -100
            tokenized_data["labels"][idx, eos_positions[1:]] = -100

    return tokenized_data
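
Here is a minimal sketch of how the collate function could be wired into a DataLoader via functools.partial. The GPT-2 tokenizer, the toy two-example dataset, and the context length of 16 are placeholders for illustration, not the real training setup:

from functools import partial

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Placeholder setup: swap in the real tokenizer, dataset, and context length
tokenizer = AutoTokenizer.from_pretrained("gpt2")
toy_dataset = [
    {"text": "Hello world" + tokenizer.eos_token},
    {"text": "Another short example" + tokenizer.eos_token},
]

loader = DataLoader(
    toy_dataset,
    batch_size=2,
    collate_fn=partial(
        custom_collate_fn_train_on_first_eos_occurrence,
        tokenizer=tokenizer,
        context_length=16,
    ),
)

batch = next(iter(loader))
# Each tensor is (batch_size, context_length); padding eos_tokens after the
# first one have label -100, so the loss ignores them
print(batch["input_ids"].shape, batch["labels"].shape)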