Can only automatically infer lengths for datasets whose items are dictionaries with an 'input_ids' key

Hi,

I’m trying to train a Llama model on a custom dataset. This is my dataset:

import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoTokenizer

READER_MODEL_NAME = "NousResearch/Llama-2-7b-chat-hf"


class RedditDataset(Dataset):
    def __init__(
        self,
        reddit_data: pd.DataFrame,
        rag,
        tokenizer,
    ):
        self.rag = rag
        self.data = reddit_data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        question = row["question"]
        context = self.rag.retrieve_context(question)

        prompt = [
            {
                "role": "system",
                "content": """Using the information contained in the context,
give a comprehensive and concise answer to the question.
Respond only to the question asked, response should be concise and relevant to the question.
Provide the number of the rule when relevant.
If the answer cannot be deduced from the context, do not give an answer.
The questions are related with Magic The Gathering card game.""",
            },
            {
                "role": "user",
                "content": f"""Context:
{context}
---
Now here is the question you need to answer.

Question: {question}""",
            },
            {"role": "assistant", "content": f"Answer: {row['answer']}"},
        ]

        return self.tokenizer.apply_chat_template(prompt, tokenize=False)

tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Loading Reddit dataset")
reddit_df = pd.read_csv("data/reddit/reddit_qa_dataset.csv")
train, test = train_test_split(reddit_df, test_size=0.2)

print("Creating datasets")
train_dataset = RedditDataset(train, rag, tokenizer)
test_dataset = RedditDataset(test, rag, tokenizer)

As you can see, it uses a RAG system and a prompt template. That shouldn’t be an issue, because this is the output it generates:

train_dataset[0]
{'text': '<s>[INST] <<SYS>>\nUsing the information contained in the context,\ngive a comprehensive and concise answer to the question.\nRespond only to the question asked, response should be concise and relevant to the question.\nProvide the number of the rule when relevant.\nIf the answer cannot be deduced from the context, do not give an answer.\nThe questions are related with Magic The Gathering card game.\n<</SYS>>\n\nContext:\n\nExtracted documents:\nDocument 0:::\nName: Rabbit Battery\nMana Cost: red\nType: Artifact Creature — Equipment Rabbit\nText: Haste\nEquipped creature gets +1/+1 and has haste.\nReconfigure red,   red: Attach to target creature you control; or unattach from a creature. Reconfigure only as a sorcery. While attached, this isn\'t a creature.)\nStats: 1 power, 1 toughness\nRules:\n1. Although it causes an Equipment to become attached to a creature, reconfigure is not an “equip ability” for the purpose of cards like Fighter Class and Leonin Shikari.\n2. An Equipment creature can never become attached to itself. If an effect tries to do this, nothing happens.\n3. An Equipment creature with reconfigure can be attached to creatures by effects other than its reconfigure ability, such as the activated ability of Brass Squire.\n4. An Equipment doesn\'t become tapped when the permanent it\'s attached to becomes tapped. For example, if you attack with a creature that is equipped with Acquisition Octopus, then use reconfigure to unattach Acquisition Octopus after combat, the Octopus will be untapped and could be used to block during your opponent\'s turn.\n5. As soon as an Equipment creature with reconfigure stops being a creature, any Equipment and Auras with enchant creature abilities become unattached. Auras that can enchant an Equipment that isn\'t a creature remain attached to it.\n6. Attaching an Equipment with reconfigure to a creature causes that Equipment to stop being a creature until it becomes unattached. It also loses any creature subtypes it had.\n7. If a permanent with reconfigure is somehow still a creature after it becomes attached (perhaps due to an effect like that of March of the Machines), it immediately becomes unattached from the equipped creature.\n8. If an Equipment with reconfigure somehow loses its abilities while it is attached, the effect causing it to not be a creature continues to apply until it becomes unattached.Document 1:::\nName: Kodama of the Center Tree\nMana Cost: 4 colorless, green\nType: Legendary Creature — Spirit\nText: Kodama of the Center Tree\'s power and toughness are each equal to the number of Spirits you control.\nKodama of the Center Tree has soulshift X, where X is the number of Spirits you control. (When this creature dies, you may return target Spirit card with mana value X or less from your graveyard to your hand.)\nStats:  power,  toughness\nRules:\n1. Kodama of the Center Tree can return itself to its owner’s hand if you control five or more Spirits when it is put into a graveyard from the battlefield.\n2. Soulshift is a leaves the battlefield trigger, so the gamestate is referenced immediately before the Soulshift trigger to determine the value of X. Soulshift X includes Kodama of the Center Tree. So, X is always at least 1.\n\nLegal in: commander, duel, legacy, modern, oathbreaker, penny, predh, vintageDocument 2:::\nName: Find // Finality\nMana Cost: black green, black green\nType: Sorcery\nText: Return up to two target creature cards from your graveyard to your hand.\n\nRules:\n1. 
Finality affects only creatures on the battlefield at the time it resolves. Creatures that enter the battlefield later in the turn won’t get -4/-4.\n2. Finality doesn’t target the creature to receive +1/+1 counters. You can cast it even if you control no creatures.\n\nLegal in: brawl, commander, duel, explorer, gladiator, historic, legacy, modern, oathbreaker, penny, pioneer, timeless, vintageDocument 3:::\nName: From the Rubble\nMana Cost: 4 colorless, white, white\nType: Enchantment\nText: As From the Rubble enters the battlefield, choose a creature type.\nAt the beginning of your end step, return target creature card of the chosen type from your graveyard to the battlefield with a finality counter on it. (If a creature with a finality counter on it would die, exile it instead.)\n\nRules:\n1. Finality counters aren\'t keyword counters, and a finality counter doesn\'t give any abilities to the permanent it\'s on. If that permanent loses its abilities and then would go to a graveyard, it will still be exiled instead.\n2. Finality counters don\'t stop permanents from going to zones other than the graveyard from the battlefield. For example, if a permanent with a finality counter on it would be put into its owner\'s hand from the battlefield, it does so normally.\n3. Finality counters work on any permanent, not only creatures. If a permanent with a finality counter on it would go to a graveyard from the battlefield, exile it instead.\n4. Multiple finality counters on a single permanent are redundant.\n\nLegal in: commander, duel, legacy, oathbreaker, vintageDocument 4:::\nName: Find // Finality1\nMana Cost: 4 colorless, black, green\nType: Sorcery\nText: You may put two +1/+1 counters on a creature you control. Then all creatures get -4/-4 until end of turn.\n\nRules:\n1. Finality affects only creatures on the battlefield at the time it resolves. Creatures that enter the battlefield later in the turn won’t get -4/-4.\n2. Finality doesn’t target the creature to receive +1/+1 counters. You can cast it even if you control no creatures.\n\nLegal in: brawl, commander, duel, explorer, gladiator, historic, legacy, modern, oathbreaker, penny, pioneer, timeless, vintage\n---\nNow here is the question you need to answer.\n\nQuestion: Rabbit Battery + Find // Finality\nHey all, \n\nI had a question about a ruling. I had a Kodama of the West Tree and a Rabbit Battery equipped to it. My opponent played Find // Finality which gave all creatures -4/-4 (until end of turn) how does that work on the stack? I killed my Kodama of the West Tree but am unsure if the Rabbit Battery will survive since it was equipment when the spell was played, or if the "end of turn" part will kill it after it becomes a creature again.\n\nThank you [/INST] Answer: Rabbit Battery survives.\n\nThe text "all creatures get -4/-4 until end of turn" means that you find all the creatures, take note of what they are, and give them -4/-4 until end of turn. While the time we "find all the creatures", Rabbit Battery is not a creature, so it\'s not affected. (Note: Similarly, creatures that enter the battlefield later this turn are also not affected.)\n\nAnother thing: technically, the thing that kills Kodama is the state-based action (SBA), which is performed only immediately after Find//Finality has finished resolving and has been put into the graveyard already. </s>'}
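For what it’s worth, this is the quick sanity check I’d run on a single item, mirroring the condition that shows up in the traceback further down (just a sketch, using the objects defined above):

# Check what one dataset item looks like to the Trainer's sampler
item = train_dataset[0]
print(type(item))
# The length-grouping sampler needs this to be True:
print(isinstance(item, dict) and "input_ids" in item)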

I’m trying to train this model:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantization Config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)
# Model
base_model = AutoModelForCausalLM.from_pretrained(
    READER_MODEL_NAME, quantization_config=quant_config, device_map={"": 0}
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1

These are the training configs:

from transformers import TrainingArguments

from peft import LoraConfig
from trl import SFTTrainer

# LoRA Config
peft_parameters = LoraConfig(
    lora_alpha=16, lora_dropout=0.1, r=8, bias="none", task_type="CAUSAL_LM"
)

# Training Params
train_params = TrainingArguments(
    output_dir="./results_modified",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    # report_to="tensorboard",
)

# Trainer
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    peft_config=peft_parameters,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=train_params,
)

# Training
fine_tuning.train()

And I’m getting this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 45
     35 fine_tuning = SFTTrainer(
     36     model=base_model,
     37     train_dataset=train_dataset,
   (...)
     41     args=train_params,
     42 )
     44 # Training
---> 45 fine_tuning.train()

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/trl/trainer/sft_trainer.py:361, in SFTTrainer.train(self, *args, **kwargs)
    358 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
    359     self.model = self._trl_activate_neftune(self.model)
--> 361 output = super().train(*args, **kwargs)
    363 # After training we make sure to retrieve back the original forward pass method
    364 # for the embedding layer by removing the forward post hook.
    365 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1883         hf_hub_utils.enable_progress_bars()
   1884 else:
-> 1885     return inner_training_loop(
   1886         args=args,
   1887         resume_from_checkpoint=resume_from_checkpoint,
   1888         trial=trial,
   1889         ignore_keys_for_eval=ignore_keys_for_eval,
   1890     )

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/transformers/trainer.py:1914, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1912 logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
   1913 # Data loader and number of training steps
-> 1914 train_dataloader = self.get_train_dataloader()
   1915 if self.is_fsdp_xla_v2_enabled:
   1916     train_dataloader = tpu_spmd_dataloader(train_dataloader)

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/transformers/trainer.py:892, in Trainer.get_train_dataloader(self)
    883 dataloader_params = {
    884     "batch_size": self._train_batch_size,
    885     "collate_fn": data_collator,
   (...)
    888     "persistent_workers": self.args.dataloader_persistent_workers,
    889 }
    891 if not isinstance(train_dataset, torch.utils.data.IterableDataset):
--> 892     dataloader_params["sampler"] = self._get_train_sampler()
    893     dataloader_params["drop_last"] = self.args.dataloader_drop_last
    894     dataloader_params["worker_init_fn"] = seed_worker

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/transformers/trainer.py:854, in Trainer._get_train_sampler(self)
    852         lengths = None
    853     model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
--> 854     return LengthGroupedSampler(
    855         self.args.train_batch_size * self.args.gradient_accumulation_steps,
    856         dataset=self.train_dataset,
    857         lengths=lengths,
    858         model_input_name=model_input_name,
    859     )
    861 else:
    862     return RandomSampler(self.train_dataset)

File ~/miniconda3/envs/gatherer-sage/lib/python3.9/site-packages/transformers/trainer_pt_utils.py:650, in LengthGroupedSampler.__init__(self, batch_size, dataset, lengths, model_input_name, generator)
    645     model_input_name = model_input_name if model_input_name is not None else "input_ids"
    646     if (
    647         not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
    648         or model_input_name not in dataset[0]
    649     ):
--> 650         raise ValueError(
    651             "Can only automatically infer lengths for datasets whose items are dictionaries with an "
    652             f"'{model_input_name}' key."
    653         )
    654     lengths = [len(feature[model_input_name]) for feature in dataset]
    655 elif isinstance(lengths, torch.Tensor):

ValueError: Can only automatically infer lengths for datasets whose items are dictionaries with an 'input_ids' key.

I don’t understand what is happening, because the same config works with a different dataset, like in this article: https://deci.ai/blog/fine-tune-llama-2-with-lora-for-question-answering/. The input structure of my dataset is the same as theirs.
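The main difference I can spot is that, if I remember right, the article hands SFTTrainer a Hugging Face datasets.Dataset with a "text" column, while mine is a plain torch Dataset that returns rendered strings. If that’s what matters, I guess I could pre-render the prompts into a datasets.Dataset instead, something like this (untested sketch, reusing my rag and tokenizer objects from above):

from datasets import Dataset as HFDataset

def render_example(row):
    # Build the same chat prompt as RedditDataset.__getitem__ and render it to text
    context = rag.retrieve_context(row["question"])
    prompt = [
        {"role": "system", "content": "..."},  # same system prompt as above, shortened here
        {
            "role": "user",
            "content": f"Context:\n{context}\n---\nNow here is the question you need to answer.\n\nQuestion: {row['question']}",
        },
        {"role": "assistant", "content": f"Answer: {row['answer']}"},
    ]
    return {"text": tokenizer.apply_chat_template(prompt, tokenize=False)}

train_hf_dataset = HFDataset.from_pandas(train).map(render_example)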

Why does my SFTTrainer ask for “input_ids” and fail to gather them automatically? That’s what it is supposed to do with the tokenizer and dataset_text_field parameters, right?
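If I’m reading the traceback right, the LengthGroupedSampler is only built because group_by_length=True (the else branch falls back to a RandomSampler), so I suppose one workaround would be to drop that option, even though I’d still like to understand the real cause:

# Same TrainingArguments as above, just without length grouping, so the
# Trainer falls back to a RandomSampler instead of LengthGroupedSampler.
train_params = TrainingArguments(
    output_dir="./results_modified",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    group_by_length=False,
)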
