But the distribution shift is conditional, since decoders are autoregressive. My prediction is that only (or mostly) the probability of eos will increase, and only in contexts where one or more eos tokens have already been seen. However, I've seen reports elsewhere that a model fine-tuned with eos = pad predicts eos far too often (e.g., it predicts nothing but eos). If this really is an issue, one way to fix it is to keep the first eos in the loss and mask out the labels of every remaining eos (padding) token.
Code for that:
# -- Define custom collate function
def custom_collate_fn(
    data: list[dict[str, str]],
    tokenizer: "PreTrainedTokenizer",
    max_length: "int | None" = None,
) -> dict[str, torch.Tensor]:
    """Tokenize a batch of text examples, training on only the first eos of each row.

    When the tokenizer pads with its eos token, every padding position is also an
    eos token. Training on all of them pushes the model towards predicting eos far
    too often. This collate function therefore:

    * sets ``attention_mask`` to 1 at the first eos of each row, so that token is
      attended to and kept in the loss (the model learns to stop there);
    * sets ``labels`` to -100 at every later eos in the row; and
    * sets ``labels`` to -100 at every position whose (updated) ``attention_mask``
      is 0. This also masks padding for tokenizers whose pad token is not eos.

    -100 is the default ``ignore_index`` of PyTorch's cross-entropy loss, so those
    positions do not contribute to the loss.

    NOTE(review): assumes right padding. With ``padding_side == "left"`` and
    pad == eos, the "first eos" would be a leading pad token. Confirm the
    tokenizer's padding side.

    Args:
        data: Batch of examples. ``example["text"]`` is used; a missing or falsy
            value is treated as the empty string.
        tokenizer: Hugging Face tokenizer. If it has no ``pad_token``, its
            ``pad_token`` is set to ``eos_token`` *in place*, and this change
            persists on the caller's tokenizer.
        max_length: Length every row is padded/truncated to. If None, the
            tokenizer's predefined model maximum length is used (Hugging Face
            default for ``padding="max_length"``).

    Returns:
        The tokenizer output (``input_ids``, and ``attention_mask`` with the
        first-eos adjustment) plus a ``labels`` tensor.

    Raises:
        ValueError: If the tokenizer has no eos token id.

    ref: https://discuss.huggingface.co/t/why-does-the-falcon-qlora-tutorial-code-use-eos-token-as-pad-token/45954/13?u=brando
    ref: https://chat.openai.com/share/02d16770-a1f3-4bf4-8fc2-464286daa8a1
    ref: https://claude.ai/chat/80565d1f-ece3-4fad-87df-364ce57aec15 on when to call .clone()
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError("custom_collate_fn requires a tokenizer with an eos token.")

    # -- Ensure the tokenizer has a padding token (fall back to eos).
    # NOTE: when training strictly at full context length no padding should occur;
    # removing this fallback would then make accidental padding fail loudly.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -- Extract the raw text; a missing, None or empty "text" becomes "".
    sequences: list[str] = [example.get("text", "") or "" for example in data]

    # -- Tokenize, padding/truncating every row to the same length.
    tokenized_data = tokenizer(
        sequences,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized_data["input_ids"]
    attention_mask = tokenized_data["attention_mask"]  # updated in place below

    # -- Find the first eos of each row without a Python loop: the running count
    # of eos tokens is exactly 1 from the first eos up to (excluding) the second,
    # so AND-ing with `is_eos` keeps only the first occurrence.
    is_eos = input_ids == eos_token_id
    first_eos = is_eos & (is_eos.long().cumsum(dim=1) == 1)

    # Attend to the first eos so it is part of the training signal.
    attention_mask[first_eos] = 1

    # HF causal-LM models compute the loss from the "labels" key.
    # Clone so that masking labels does not alter input_ids.
    labels = input_ids.clone()
    # Every eos after the first is padding (when pad == eos): exclude it from the loss.
    labels[is_eos & ~first_eos] = -100
    # Any remaining padding (covers pad != eos tokenizers): exclude from the loss.
    labels[attention_mask == 0] = -100
    tokenized_data["labels"] = labels
    return tokenized_data