DatasetDict({
train: Dataset({
features: ['text', 'messages'],
num_rows: 20347
})
test: Dataset({
features: ['text', 'messages'],
num_rows: 2261
})
})
The above is the format of my dataset.
from trl import setup_chat_format
# Attach a chat template and its special tokens to the model/tokenizer pair so
# conversational ("messages") datasets can be rendered to text.
# NOTE(review): presumably the base model ships without a chat template —
# setup_chat_format is only needed (and safe) in that case; confirm.
model,tokenizer = setup_chat_format(model,tokenizer)
# Training configuration for TRL's SFTTrainer.
# Effective batch size = 8 (per device) * 2 (grad accumulation) = 16.
args = SFTConfig(
    output_dir="lora_model/",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-05,
    gradient_accumulation_steps=2,
    max_steps=300,
    logging_strategy="steps",
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    eval_strategy="steps",
    eval_steps=25,
    fp16=True,
    data_seed=42,
    max_seq_length=2048,
    # FIX for KeyError: 'messages'. The "messages" column holds chat turns
    # (a list of role/content dicts), not a plain string. During dataset
    # preparation TRL applies the chat template first, which *replaces* the
    # "messages" column with a rendered "text" column, and only afterwards
    # tokenizes `dataset_text_field` — so pointing it at "messages" looks up
    # a column that no longer exists (the KeyError in the traceback).
    # "text" is also the default, so this line could simply be omitted.
    dataset_text_field="text",
)
# Build the trainer. The mojibake quotes (âtrainâ) in the original were not
# valid Python string literals; restored to plain ASCII quotes.
trainer = SFTTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,   # tokenizer passed via the new-style argument
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
I'm performing fine-tuning using SFT, and I get the following error when I run the cell below:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-31-9e22427f7dd7> in <cell line: 1>()
----> 1 trainer = SFTTrainer(
2 model = model,
3 args = args,
4 processing_class = tokenizer,
5 train_dataset = dataset['train'],
/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py in wrapped_func(*args, **kwargs)
163 warnings.warn(message, FutureWarning, stacklevel=2)
164
--> 165 return func(*args, **kwargs)
166
167 return wrapped_func
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, processing_class, compute_loss_func, compute_metrics, callbacks, optimizers, optimizer_cls_and_kwargs, preprocess_logits_for_metrics, peft_config, formatting_func)
196 preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
197 if preprocess_dataset:
--> 198 train_dataset = self._prepare_dataset(
199 train_dataset, processing_class, args, args.packing, formatting_func, "train"
200 )
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in _prepare_dataset(self, dataset, processing_class, args, packing, formatting_func, dataset_name)
409 if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
410 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
--> 411 dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
412
413 # Pack or truncate
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
558 "columns": self._format_columns,
559 "output_all_columns": self._output_all_columns,
--> 560 }
561 # apply actual function
562 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3071 except NonExistentDatasetError:
3072 pass
-> 3073 if transformed_dataset is None:
3074 with hf_tqdm(
3075 unit=" examples",
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in _map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
3444 tasks.append(loop.create_task(async_apply_function(example, i, offset=offset)))
3445 # keep the total active tasks under a certain number
-> 3446 if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
3447 done, pending = loop.run_until_complete(
3448 asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
3336 effective_indices = [i + offset for i in indices] if isinstance(indices, list) else indices + offset
3337 additional_args = ()
-> 3338 if with_indices:
3339 additional_args += (effective_indices,)
3340 if with_rank:
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in <lambda>(ex)
409 if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
410 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
--> 411 dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
412
413 # Pack or truncate
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in __getitem__(self, key)
275
276 def __len__(self):
--> 277 return len(self.data)
278
279 def __getitem__(self, key):
KeyError: 'messages'