Model won't load on custom inference endpoint

Hi, I’ve fine-tuned a Llama 3 model using unsloth and saved the LoRA weights to a HF repo.
I’ve used this LoRA in Colab and locally with no issues, but when I try to write a handler for it I get errors about class and model-weight discrepancies.
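
For context, this is roughly how the adapter gets loaded in Colab / locally (a minimal sketch, not my exact code; the base model and repo ids below are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Placeholder ids, not the actual repos
base_id = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter_id = "my-user/my-llama3-lora"

base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(base, adapter_id)  # attaches the LoRA weights on top of the base model
tokenizer = AutoTokenizer.from_pretrained(adapter_id)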

Here are the error messages:

• 2024/06/12 12:00:39
ERROR: Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params) # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 253, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 217, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 333, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_llama.py", line 84, in __init__
    model = FlashLlamaForCausalLM(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 385, in __init__
    self.model = FlashLlamaModel(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 309, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 310, in <listcomp>
    FlashLlamaLayer(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 249, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 126, in __init__
    self.query_key_value = load_attention(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 43, in load_attention
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 85, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096]

• 2024/06/12 12:00:40
ERROR: Shard complete standard error output:

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'.
The class this function is called from is 'LlamaTokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:658: UserWarning: You are using a Backend <class 'text_generation_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  warnings.warn(
Traceback (most recent call last):

 File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
    server.serve(

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 253, in serve
    asyncio.run(

 File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

 File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 217, in serve_inner
    model = get_model(

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 333, in /opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 333, in get_model
    return FlashLlama(

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_llama.py", line 84, in __init__
    model = FlashLlamaForCausalLM(prefix, config, weights)

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 385, in __init__
    self.model = FlashLlamaModel(prefix, config, weights)

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 309, in __init__
    [

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 310, in \n\n <listcomp>\n FlashLlamaLayer(\n\n File \"/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py\"

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 249, in __init__
    self.self_attn = FlashLlamaAttention(

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 126, in __init__
    self.query_key_value = load_attention(config, prefix, weights)

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 43, in load_attention
    return _load_gqa(config, prefix, weights)

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 85, in _load_gqa
    assert list(weight.shape) == [

AssertionError: [12582912, 1] != [6144, 4096]
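
For reference, the [6144, 4096] that TGI expects matches the fused Q/K/V projection of a standard Llama 3 8B config, while [12582912, 1] is the same number of parameters packed two-per-byte into a single column, which is what a bitsandbytes 4-bit tensor looks like on disk. This is only my own back-of-the-envelope check, assuming the stock 8B config rather than anything taken from the logs:

# Sanity check on the shapes in the assertion, assuming the standard Llama 3 8B
# config: hidden_size=4096, 32 attention heads, 8 KV heads, head_dim=128.
hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128

expected_rows = (num_heads + 2 * num_kv_heads) * head_dim  # fused Q/K/V output dim
print([expected_rows, hidden_size])                        # [6144, 4096] -> what TGI expects

# The shape actually found, [12582912, 1], is numel / 2 in a single column,
# i.e. the layout of a 4-bit packed (quantized) weight rather than a plain fp16/bf16 matrix.
print([expected_rows * hidden_size // 2, 1])               # [12582912, 1]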

This is my last attempt, where I removed unsloth from the handler; I also tried loading with unsloth and hit the same issue.

from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import os
import torch
from subprocess import run
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 needs compute capability >= 8 (Ampere or newer); fall back to float16 otherwise
dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# Note: this runs at import time, after torch has already been imported above,
# so the already-loaded torch module is the one this process keeps using.
run("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121", shell=True, check=True)

class EndpointHandler():
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        # pseudo
        # self.model = load_model(path)
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")

        print("loading model")
        
        tokenizer = AutoTokenizer.from_pretrained(path, token=self.HF_READ_TOKEN)
        
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=path,
            token=self.HF_READ_TOKEN,
            torch_dtype=dtype,
        ).to(device)
        
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        self.alpaca_prompt = """REDACTED"""
        print("model loaded")
        
        
 
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            kwargs
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """

        # pseudo
        # self.model(input)
        if data["input"] is not Null:
            request = data.pop("input",data)
            inputs = self.alpaca_prompt.format(request)
            prediction = self.pipeline(inputs)
            return {"prediction": prediction}
        else:
            return [{"Error" : "no input received."}]

I managed to fix this issue by pinning specific versions of TRL and peft! The installed huggingface_hub version was also part of the problem.
Here’s the requirements.txt, for anyone who runs into the same issue in the future:

huggingface_hub
unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
xformers
trl<0.9.0 
peft==0.11.1
bitsandbytes
transformers
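
To double-check which versions actually end up installed in the endpoint environment, a small snippet like this (not part of the handler, just a hypothetical check) can be dropped into the handler’s __init__ or a startup log:

from importlib.metadata import PackageNotFoundError, version

for pkg in ("huggingface_hub", "unsloth", "xformers", "trl", "peft", "bitsandbytes", "transformers"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")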