Unable to load a saved custom model

Hi folks… I added a new embedding layer to a llama2-hf model and saved it. Saving went through easily enough, but loading the model with from_pretrained gives the following error:
ValueError: The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved

Please find below the code for the custom class and the line I used to save the model:

class CustomLlama2Model(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # llama2_hf and llama_peft_config are defined at module level
        self.llama_peft_model = get_peft_model(llama2_hf, llama_peft_config)

        # additional embedding model saved earlier with torch.save
        self.embed_save_path = '../../CLR/DISTIL_FINE_TUNE'
        self.embed_model = torch.load(self.embed_save_path)
        self.embed_model.eval()

        # Llama-2 hidden size and the extra embedding size
        self.last_hidden_sz_llama2_, self.embedding_sz_ = 4096, 768

        # project the concatenated hidden state back to the vocabulary
        self.lm_head = torch.nn.Linear(self.last_hidden_sz_llama2_ + self.embedding_sz_,
                                       llama2_hf.config.vocab_size, bias=False)


    def forward(self, input_ids, attention_mask, spatial_coords, labels=None):

        #print( 'input_ids->', input_ids.shape, input_ids )
        # hidden_states is only returned when output_hidden_states is enabled
        llama2_output = self.llama_peft_model(input_ids, attention_mask=attention_mask,
                                              output_hidden_states=True)
        additional_embedding_output = self.embed_model.spatial_embedding(spatial_coords)

        print('2 of a kind->', llama2_output['hidden_states'][-1].shape, additional_embedding_output.shape)
        # concatenate the last Llama-2 hidden state with the additional embedding
        combined_output = torch.cat([llama2_output['hidden_states'][-1], additional_embedding_output], dim=-1)

        #return combined_output, labels
        ## NOTE - all of the code below comes from
        ## https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/models/llama/modeling_llama.py#L847C29-L847C29
        logits = self.lm_head(combined_output)
        ## pass it through a linear layer (or a couple of lm heads) to reduce the size to the vocab size
        logits = logits.float()


        loss = None  # keep loss defined even when labels are not supplied
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            #print( 'LOGITS->', torch.argmax( shift_logits, dim=-1 ) )
            #print( 'LABELS->', ( shift_labels ) )
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view( -1, llama2_hf.config.vocab_size )
            shift_labels = shift_labels.view( -1 )
            # Enable model parallelism
            shift_labels = shift_labels.to( shift_logits.device )
            loss = loss_fct( shift_logits, shift_labels )

        print('Loss RETURN->', loss)
        return ( loss, ) ## since trainer accepts loss as a tuple
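
(For context, my understanding is that a custom PreTrainedModel subclass normally declares a config_class so that save_pretrained / from_pretrained can round-trip it, and that AutoModelForCausalLM only maps config.json onto registered architectures. A minimal sketch of that pattern, with purely illustrative names, is below - I'm not sure whether skipping it is what breaks my load.)

from transformers import PreTrainedModel, LlamaConfig

class CustomLlama2Model(PreTrainedModel):
    config_class = LlamaConfig            # tells from_pretrained which config to instantiate
    base_model_prefix = "custom_llama2"   # illustrative prefix, not taken from my actual code

    def __init__(self, config):
        super().__init__(config)
        # ... build the submodules from `config` only, so from_pretrained can
        # re-create the architecture before loading the saved state dict ...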

Here are the lines I use to invoke the model via the Trainer:

custom_model = CustomLlama2Model( llama2_hf.config )
trainer = transformers.Trainer(
    model=custom_model,
    #model=model,
    train_dataset=train_dataset,
    #train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=1,
        output_dir="custom-lora-dolly",
        save_total_limit=3,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train(resume_from_checkpoint=False)
#trainer.save_model( "custom-embedding-alpaca-lora-dolly-2.0" )
custom_model.save_pretrained( "custom-embedding-alpaca-lora-dolly-2.0" )
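
(In case it helps with diagnosing the error, this is roughly how one could check which tensors actually ended up in the checkpoint. It assumes the default, unsharded file names that save_pretrained writes, i.e. model.safetensors or pytorch_model.bin.)

import os
import torch
from safetensors.torch import load_file

save_dir = "custom-embedding-alpaca-lora-dolly-2.0"
st_path = os.path.join(save_dir, "model.safetensors")
bin_path = os.path.join(save_dir, "pytorch_model.bin")

# load the raw state dict and list a few of its keys
state_dict = load_file(st_path) if os.path.exists(st_path) else torch.load(bin_path, map_location="cpu")
print(list(state_dict.keys())[:20])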

I then load the model using:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=False, load_in_4bit=True
)

llama2_hf = AutoModelForCausalLM.from_pretrained("custom-embedding-alpaca-lora-dolly-2.0",
                                             device_map=torch.device("cpu"),
                                             quantization_config=quantization_config,
                                             torch_dtype=torch.bfloat16)

which is where I hit the error described above. Please help!
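
(One thing I wondered about: should the reload go through the custom class itself rather than AutoModelForCausalLM, which only resolves registered architectures? A sketch of what I mean is below, assuming CustomLlama2Model and everything it depends on are importable in the loading script.)

from transformers import LlamaConfig

# config.json was written by save_pretrained from llama2_hf.config, so it should load as a LlamaConfig
config = LlamaConfig.from_pretrained("custom-embedding-alpaca-lora-dolly-2.0")
reloaded = CustomLlama2Model.from_pretrained(
    "custom-embedding-alpaca-lora-dolly-2.0",
    config=config,
)
reloaded.eval()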