Model won't load on custom inference endpoint

Hi, I’ve fine-tuned a Llama 3 model using Unsloth and saved the LoRA weights to a HF repo.
I’ve used this LoRA in Colab and locally with no issues, but when trying to write a custom handler for it, I get errors about class and model weight discrepancies.

Here are the error messages:

• 2024/06/12 12:00:39
ERROR: Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
  File "/opt/conda/lib/python3.10/site-packages/typer/", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/", line 683, in wrapper
    return callback(**use_params) # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/", line 90, in serve
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/", line 253, in serve
  File "/opt/conda/lib/python3.10/asyncio/", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/", line 636, in run_until_complete
  File "/opt/conda/lib/python3.10/asyncio/", line 603, in run_forever
  File "/opt/conda/lib/python3.10/asyncio/", line 1909, in _run_once
  File "/opt/conda/lib/python3.10/asyncio/", line 80, in _run, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/", line 217, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/", line 333, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/", line 84, in __init__
    model = FlashLlamaForCausalLM(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 385, in __init__
    self.model = FlashLlamaModel(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 309, in __init__
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 310, in <listcomp>
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 249, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 126, in __init__
    self.query_key_value = load_attention(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 43, in load_attention
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/", line 85, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096]

This is my latest attempt, made after removing unsloth; I also tried loading the model with unsloth, but I had the same issue.

from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import os
import torch
from subprocess import run
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prefer bfloat16 on GPUs whose major compute capability is 8 or higher and
# use float16 everywhere else. The capability is only queried when CUDA is
# available, because it cannot be read on a CPU-only host -- the very case the
# `device` fallback above is meant to support.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

# NOTE(review): `--index-url` has no value in this snippet (presumably the URL
# was lost when the code was pasted), so pip will reject the command as
# written; supply the intended index URL. Also note that torch has already
# been imported by the time this runs -- confirm the reinstall is still needed.
run("pip3 install torch torchvision torchaudio --index-url", shell=True, check=True)

class EndpointHandler():
    """Custom endpoint handler wrapping a ``text-generation`` pipeline.

    The model and tokenizer are loaded once in ``__init__``; each call formats
    the incoming request with ``self.alpaca_prompt`` and runs the pipeline on
    the resulting prompt.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model found at *path* and build the pipeline.

        Args:
            path: Model location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # Access token read from the environment; None when the variable is
        # not set.
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")

        print("loading model")
        tokenizer = AutoTokenizer.from_pretrained(path, token=self.HF_READ_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=path,
            token=self.HF_READ_TOKEN,
        )
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        # Prompt template; the request is substituted in with str.format().
        self.alpaca_prompt = """REDACTED"""
        print("model loaded")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for one request payload.

        Args:
            data: Request payload. The prompt content is read from (and
                removed from) the ``"input"`` key.

        Returns:
            ``{"prediction": <pipeline output>}`` when an input is present,
            otherwise ``[{"Error": "no input received."}]``. Both shapes are
            kept as in the original handler, so the success case is a dict
            even though the annotation says list.
        """
        # A missing key and an explicit None are both treated as "no input";
        # .get() avoids the KeyError that indexing would raise.
        if data.get("input") is None:
            return [{"Error" : "no input received."}]

        request = data.pop("input")
        inputs = self.alpaca_prompt.format(request)
        prediction = self.pipeline(inputs)
        return {"prediction": prediction}

I managed to fix this issue by ensuring that specific versions of TRL and peft were installed! Also, the hf_hub package had a version that was not working.
Here’s the requirements.txt for anyone who may run into the same issue in the future:

unsloth[colab-new] @ git+
