No key 'messages' found

DatasetDict({
train: Dataset({
features: ['text', 'messages'],
num_rows: 20347
})
test: Dataset({
features: ['text', 'messages'],
num_rows: 2261
})
})

The above is my dataset format.

# Prepare the model/tokenizer for chat-style fine-tuning: adds the chat
# special tokens and sets tokenizer.chat_template.
from trl import setup_chat_format

model, tokenizer = setup_chat_format(model, tokenizer)

args = SFTConfig(
    output_dir="lora_model/",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-05,
    gradient_accumulation_steps=2,
    max_steps=300,
    logging_strategy="steps",
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    eval_strategy="steps",
    eval_steps=25,
    fp16=True,
    data_seed=42,
    max_seq_length=2048,
    # NOTE: do NOT set dataset_text_field="messages". For a conversational
    # dataset, SFTTrainer._prepare_dataset first applies the chat template,
    # which renders the "messages" column into a "text" column and removes
    # "messages"; by the time the tokenization map runs (the lambda at
    # sft_trainer.py:411 in the traceback), ex["messages"] no longer exists,
    # which is exactly the KeyError: 'messages' reported below. The default
    # value "text" matches the column that preprocessing step produces.
    dataset_text_field="text",
)

trainer = SFTTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

I am performing fine-tuning using SFT, and I get an error when I run the cell below:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-31-9e22427f7dd7> in <cell line: 1>()
----> 1 trainer = SFTTrainer(
      2     model = model,
      3     args = args,
      4     processing_class = tokenizer,
      5     train_dataset = dataset['train'],

/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py in wrapped_func(*args, **kwargs)
    163                 warnings.warn(message, FutureWarning, stacklevel=2)
    164 
--> 165             return func(*args, **kwargs)
    166 
    167         return wrapped_func

/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, processing_class, compute_loss_func, compute_metrics, callbacks, optimizers, optimizer_cls_and_kwargs, preprocess_logits_for_metrics, peft_config, formatting_func)
    196         preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
    197         if preprocess_dataset:
--> 198             train_dataset = self._prepare_dataset(
    199                 train_dataset, processing_class, args, args.packing, formatting_func, "train"
    200             )

/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in _prepare_dataset(self, dataset, processing_class, args, packing, formatting_func, dataset_name)
    409             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
    410                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
--> 411             dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
    412 
    413             # Pack or truncate

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    558             "columns": self._format_columns,
    559             "output_all_columns": self._output_all_columns,
--> 560         }
    561         # apply actual function
    562         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3071             except NonExistentDatasetError:
   3072                 pass
-> 3073             if transformed_dataset is None:
   3074                 with hf_tqdm(
   3075                     unit=" examples",

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in _map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
   3444                     tasks.append(loop.create_task(async_apply_function(example, i, offset=offset)))
   3445                     # keep the total active tasks under a certain number
-> 3446                     if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
   3447                         done, pending = loop.run_until_complete(
   3448                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
   3336                 effective_indices = [i + offset for i in indices] if isinstance(indices, list) else indices + offset
   3337             additional_args = ()
-> 3338             if with_indices:
   3339                 additional_args += (effective_indices,)
   3340             if with_rank:

/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in <lambda>(ex)
    409             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
    410                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
--> 411             dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
    412 
    413             # Pack or truncate

/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in __getitem__(self, key)
    275 
    276     def __len__(self):
--> 277         return len(self.data)
    278 
    279     def __getitem__(self, key):

KeyError: 'messages'
1 Like

It seems to appear when you try to access an item that doesn’t exist.

The dataset does contain the `messages` field.

1 Like