Need help fine-tuning Llama 3 with torchtune

Hi, I’m trying to LoRA fine-tune Llama 3 on my own dataset using torchtune. I’ve managed to produce a fine-tuned model with the CLI, but I can’t get the “tune run generate” command to work.
This is the error I’m getting:

Traceback (most recent call last):
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/models/convert_weights.py", line 57, in _get_mapped_key
    new_key = mapping_dict[key]
KeyError: 'tok_embeddings.weight'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/.local/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/_cli/run.py", line 179, in _run_cmd
    self._run_single_device(args)
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/usr/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/usr/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/.local/lib/python3.10/site-packages/recipes/generate.py", line 152, in <module>
    sys.exit(main())
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/home/user/.local/lib/python3.10/site-packages/recipes/generate.py", line 147, in main
    recipe.setup(cfg=cfg)
  File "/home/user/.local/lib/python3.10/site-packages/recipes/generate.py", line 47, in setup
    ckpt_dict = checkpointer.load_checkpoint()
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/utils/_checkpointing/_checkpointer.py", line 384, in load_checkpoint
    converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/models/convert_weights.py", line 152, in hf_to_tune
    new_key = _get_mapped_key(key, _FROM_HF)
  File "/home/user/.local/lib/python3.10/site-packages/torchtune/models/convert_weights.py", line 59, in _get_mapped_key
    raise Exception(
Exception: Error converting the state dict. Found unexpected key: "tok_embeddings.weight". Please make sure you're loading a checkpoint with the right format.

I’ve tried different checkpointers with this command. FullModelHFCheckpointer produces the error above; FullModelMetaCheckpointer and FullModelTorchTuneCheckpointer produce the same message, but with a much longer list of unexpected and required keys. All the tutorials in the docs use FullModelHFCheckpointer, so I’m assuming it’s the correct one.
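
To check which format the saved checkpoint actually is, this is roughly how its keys can be listed (a minimal sketch; the path is from my setup, and the “model” nesting check is a guess on my part):

import torch

state_dict = torch.load(
    "/data/Meta-Llama-3-8B-fine-tuned/meta_model_1.pt", map_location="cpu"
)
if "model" in state_dict:  # guess: some checkpointers nest the weights under "model"
    state_dict = state_dict["model"]
# Meta/torchtune-style checkpoints have keys like "tok_embeddings.weight" and
# "layers.0.attention.wq.weight"; HF-style checkpoints have keys like
# "model.embed_tokens.weight" and "model.layers.0.self_attn.q_proj.weight".
for key in list(state_dict)[:10]:
    print(key)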

This is the config file I used while running “tune run lora_finetune_single_device”:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ./model/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: ./model/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /data/Meta-Llama-3-8B-fine-tuned/
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: custom_dataset_prep.prep_custom_dataset
  tokenizer: torchtune.models.llama3.llama3_tokenizer
  source: json
  data_files: my_custom_fine_tuning_data.json
  column_map:
    dialogue: prompt
    output: response
  max_seq_len: 1024
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 2
max_steps_per_epoch: null
gradient_accumulation_steps: 64
compile: False

# Logging
output_dir: /data/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: null

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False

The fine-tuning run itself completes and saves the model without any error messages in the console or the log file.
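
For completeness, the exact training invocation was (with the config above saved as custom_lora_finetune.yaml, a filename I chose):

tune run lora_finetune_single_device --config custom_lora_finetune.yaml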

This is the config file I’m using to run “tune run generate”:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  # apply_lora_to_mlp: False
  # apply_lora_to_output: False
  # lora_rank: 8
  # lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /data/Meta-Llama-3-8B-fine-tuned/
  checkpoint_files: [
    meta_model_1.pt,
  ]
  # adapter_checkpoint: /data/Meta-Llama-3-8B-fine-tuned/adapter_1.pt
  # recipe_checkpoint: /data/Meta-Llama-3-8B-fine-tuned/recipe_state.pt
  output_dir: /data/Meta-Llama-3-8B-fine-tuned/eval/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ./model/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello?"
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null

(Uncommenting the commented-out parameters doesn’t change the output.)
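
And this is how I invoke generation (config saved as custom_generation.yaml, again a name I chose):

tune run generate --config custom_generation.yaml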

I’m assuming I messed up the configs somehow, but in case it’s relevant, here is the “custom_dataset_prep” script referenced in the fine-tuning config:

from typing import Optional, Mapping, Any, Dict
from torchtune.datasets import InstructDataset
from torchtune.data import InstructTemplate
from torchtune.modules.tokenizers import Tokenizer

class ChatTemplate(InstructTemplate):
    # Pass the dialogue column straight through as the prompt, with no extra scaffolding.
    template = "{dialogue}"

    @classmethod
    def format(
        cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None
    ) -> str:
        column_map = column_map or {}
        key_dialogue = column_map.get("dialogue", "dialogue")

        prompt = cls.template.format(dialogue=sample[key_dialogue])
        return prompt


def prep_custom_dataset(
    tokenizer: Tokenizer,
    source: str,
    data_files: str,
    column_map: Optional[Dict[str, str]],
    max_seq_len: int = 1024,
    train_on_input: bool = True,
) -> InstructDataset:
    print(column_map)  # debug: confirm the column_map from the config arrives intact
    return InstructDataset(
        tokenizer=tokenizer,
        source=source,
        template=ChatTemplate(),
        transform=None,
        column_map=column_map,
        train_on_input=train_on_input,
        max_seq_len=max_seq_len,
        data_files=data_files,
        split="train",
    )