QLoRA-trained Mixtral 8x7B deployment error on SageMaker using the text-generation-inference (TGI) image

Hi,

I was able to successfully fine-tune Mixtral 8x7B with QLoRA following the instructions in this notebook:
https://github.com/brevdev/notebooks/blob/main/mixtral-finetune.ipynb
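
For context, the notebook loads the base model in 4-bit with bitsandbytes, i.e. the standard QLoRA setup. From memory, roughly:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Standard QLoRA quantization config: NF4 4-bit weights with double
# quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)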

After training, I merged the adapter into the base model and pushed the result to the Hub:

model = model.merge_and_unload()                   # fold the LoRA adapter weights into the base weights
model.push_to_hub('mixtral_merged', private=True)  # upload the merged model to the Hugging Face Hub
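
(Here `model` is the PeftModel that training produced. As a standalone step, the merge looked roughly like this, with the adapter path as a placeholder:)

import torch
from peft import AutoPeftModelForCausalLM

# Load the trained LoRA adapter together with its base model,
# merge the adapter into the base weights, and upload.
model = AutoPeftModelForCausalLM.from_pretrained(
    "path/to/qlora-checkpoint",   # placeholder: local adapter checkpoint from training
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
merged = model.merge_and_unload()
merged.push_to_hub('mixtral_merged', private=True)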

But when I try to deploy the merged model with:

import json
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

role = sagemaker.get_execution_role()
llm_image = '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04-v1.0'

# SageMaker config
instance_type = "ml.g5.4xlarge"
number_of_gpu = 1
health_check_timeout = 300

# Define model and endpoint configuration parameters
config = {
  'HF_MODEL_ID': "brunolcb/mixtral_merged",
  'SM_NUM_GPUS': json.dumps(number_of_gpu),  # Number of GPUs used per replica
  'MAX_INPUT_LENGTH': json.dumps(4096),      # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(8192),      # Max length of the generation (including input text)
}

# Create HuggingFaceModel with the image URI
llm_model = HuggingFaceModel(
  role=role,
  image_uri=llm_image,
  env=config,
)

predictor = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  endpoint_name="EndpointMixtral",
  container_startup_health_check_timeout=health_check_timeout,  # 5 minutes to load the model
)
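
For completeness, once the endpoint is healthy I intended to query it with the standard TGI payload (the prompt and generation parameters below are just placeholders):

response = predictor.predict({
    "inputs": "What is QLoRA?",
    "parameters": {"max_new_tokens": 256, "temperature": 0.7},
})
print(response)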

the container fails to start, and the CloudWatch logs show this error:

2024-04-10T18:51:40.215660Z ERROR text_generation_launcher: Error when initializing model

Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 89, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 228, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 174, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 310, in get_model
    return FlashMixtral(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_mixtral.py", line 21, in __init__
    super(FlashMixtral, self).__init__(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_mistral.py", line 333, in __init__
    model = model_cls(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 820, in __init__
    self.model = MixtralModel(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 757, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 758, in <listcomp>
    MixtralLayer(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 692, in __init__
    self.self_attn = MixtralAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 232, in __init__
    self.query_key_value = load_attention(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 125, in load_attention
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 152, in _load_gqa
    assert list(weight.shape) == [

AssertionError: [12582912, 1] != [6144, 4096]
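
For what it's worth, 12582912 = 6144 * 4096 / 2, which is exactly a [6144, 4096] matrix packed two 4-bit values per byte. So I suspect the pushed checkpoint still contains bitsandbytes 4-bit weights rather than dequantized fp16/bf16 ones. To confirm what actually got uploaded, I believe the stored tensor shapes can be inspected with something like this (a sketch using huggingface_hub and safetensors):

from huggingface_hub import hf_hub_download, list_repo_files
from safetensors import safe_open

repo = "brunolcb/mixtral_merged"
# Download the first safetensors shard and print the stored shape
# of one attention projection weight.
shard = next(f for f in list_repo_files(repo) if f.endswith(".safetensors"))
path = hf_hub_download(repo, shard)
with safe_open(path, framework="pt") as f:
    for name in f.keys():
        if "q_proj.weight" in name:
            print(name, f.get_slice(name).get_shape())
            break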

How can I solve this and deploy merged QLoRA Mixtral weights on Amazon SageMaker?