Using, saving, and loading an encapsulated PEFT-tuned model inside a regular PyTorch model

Hi everyone,

I want to fine-tune a specific model, BAAI/bge-multilingual-gemma2, for sentence classification, and I have several questions regarding the usage of PEFT (after reading the documentation and various forum posts).

I want to apply PEFT to this base model, use it as a feature extractor that produces embeddings, and wrap it in a PyTorch module that also contains a classification head (among other things). Then I start fine-tuning.
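The forward() method in the code below calls a last_token_pool helper whose definition is elided from the post. A minimal sketch, assuming it follows the usual last-token pooling for decoder-based embedding models (take the hidden state of the last non-padding token, handling both left and right padding), could look like this; the post's actual helper is not shown, so treat this version as an assumption.

```python
import torch
from torch import Tensor


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Return the hidden state of the last non-padding token of each sequence."""
    # With left padding, every sequence ends at the final position.
    left_padding = bool(attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # With right padding, pick the last position where attention_mask == 1.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[batch_idx, sequence_lengths]


# Usage: 2 sequences, 3 positions, hidden size 2, right padding on the first one
hidden = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
pooled = last_token_pool(hidden, mask)
```

Checking for left padding first avoids an indexing step when the tokenizer pads on the left, which is a common setting for decoder models.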

Here is my code:

[...]

#---------------------------------------------------------------------------------
def train_model(model, dataloader, valid_dataloader, optimizer, scheduler = None, num_epochs=5, device="cuda"):
    model = model.to(device)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        model.train()
        
        for batch in tqdm(dataloader, total=len(dataloader), unit='row'):
            optimizer.zero_grad()
            
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device)
            )
            
            # One-hot labels
            labels = batch['label'].to(device)
        
            loss = nn.CrossEntropyLoss()(logits, labels)

            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            
            total_loss += loss.item()
            
        
        metrics = evaluate_model(model, valid_dataloader, device=device)
     
        print(f"Training Epoch {epoch + 1}, Accumulated Train Loss: {total_loss / len(dataloader)}")
        print(f"Eval: Valid Loss: {metrics['loss']}, Valid Accuracy: {metrics['accuracy']}")

#-------------------------------------------------------------------
class PreferencePredictionModel(nn.Module):
    def __init__(self, gemma_model, num_classes=2):
        super(PreferencePredictionModel, self).__init__()
        
        # Wrapped (PEFT) transformer backbone
        self.gemma_model = gemma_model
        transformer_hidden_size = gemma_model.config.hidden_size
        
        # Fully connected layers for features
        #self.feature_fc = nn.Linear(feature_dim, 64)
        
        # Final classification layer
        self.classifier = nn.Sequential(
            #nn.Linear(transformer_hidden_size + 64, 128),
            nn.Linear(transformer_hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, input_ids, attention_mask, features=None):
        outputs = self.gemma_model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = last_token_pool(outputs.last_hidden_state, attention_mask)
        
        # normalize embeddings ????
        #embeddings = F.normalize(embeddings, p=2, dim=1)
        
        # Feature processing
        #feature_output = self.feature_fc(features)
        
        # Concatenate and classify
        combined = embeddings
        logits = self.classifier(combined)
        
        return logits

#---------------------------------------------------------------------------------
[...]

lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    # only target self-attention
    target_modules=["q_proj", "k_proj", "v_proj"],
    #layers_to_transform=[i for i in range(42) if i >= config.freeze_layers],
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
    task_type=TaskType.FEATURE_EXTRACTION, #SEQ_CLS
)

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModel.from_pretrained('BAAI/bge-multilingual-gemma2', 
            torch_dtype=torch.float16, 
            device_map="auto", 
            quantization_config=quantization_config
            )

model.config.use_cache = False

model = prepare_model_for_kbit_training(model)
lora_model = get_peft_model(model, lora_config)

predictionModel = PreferencePredictionModel(gemma_model=lora_model, num_classes=2)

optimizer = optim.Adam(predictionModel.parameters())

train_model(predictionModel, dataloader_train, dataloader_valid, optimizer, scheduler=None, device=device, num_epochs=config.n_epochs)

This actually runs, but very slowly, so I want to be able to save and load the model before uploading it for further training.

1st Question

Am I allowed to wrap a PEFT-tuned model inside a PyTorch module? Do I have to use Hugging Face's Trainer class to train my model, or can I still use a classic custom batch loop?

2nd Question

I have seen multiple examples of how to save/load a PEFT-tuned model by saving and loading the base model and the adapters. But in this case, I tried to use the PyTorch saving method:

torch.save({
            'epoch': 0,
            'model_state_dict': predictionModel.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f'../CheckPoints/PreferencePredictionModel.pt')

Loading it back, by recreating an identically structured model and then loading the checkpoint into it, resulted in multiple errors:
‘Unexpected key(s) in state_dict: “gemma_model.base_model.model.layers.0.self_attn.q_proj.base_layer.weight.absmax”[…]’

Is there a way to properly save a wrapped PEFT model?

Thank you


1. Wrapping a PEFT-tuned model inside a PyTorch module:

Yes, you can wrap a PEFT-tuned model inside a custom PyTorch module, and you don't have to use Hugging Face's Trainer class. A custom training loop like your train_model function is perfectly fine: get_peft_model freezes the base model weights and leaves only the LoRA parameters trainable, while your classification head is a regular nn.Module that trains as usual. Trainer is a higher-level abstraction that runs the loop for you; a custom loop gives you direct control over the training process, the optimizer, and how the model is structured.
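A plain PyTorch loop only needs the optimizer to receive the parameters that are actually trainable. The sketch below uses a hypothetical toy Wrapper (a frozen linear "backbone" plus a trainable head) instead of the real Gemma2 + LoRA model; in the real setup the same requires_grad filter would collect the LoRA weights and the classification head, and lora_model.print_trainable_parameters() can confirm what is trainable.

```python
import torch
from torch import nn

torch.manual_seed(0)


class Wrapper(nn.Module):
    """Hypothetical stand-in: a frozen 'backbone' plus a trainable head."""

    def __init__(self, hidden: int = 8, num_classes: int = 2):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)  # plays the role of the frozen base model
        self.classifier = nn.Linear(hidden, num_classes)
        self.backbone.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(torch.relu(self.backbone(x)))


model = Wrapper()
# Hand the optimizer only the parameters that will actually be trained.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

backbone_before = model.backbone.weight.clone()
head_before = model.classifier.weight.clone()

x = torch.randn(16, 8)
y = torch.randint(0, 2, (16,))
model.train()
for _ in range(5):  # a few steps of a plain PyTorch loop, no Trainer involved
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```

After the loop the frozen backbone is unchanged and only the head has moved, which is the behavior a LoRA-wrapped backbone should show in the original train_model function as well.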

2. Saving/Loading a PEFT-tuned Model:

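The error you're encountering comes from how the quantized PEFT model is stored. Keys such as "...q_proj.base_layer.weight.absmax" are bitsandbytes 4-bit quantization statistics that are saved next to each quantized base weight; a plain load_state_dict() does not restore them, so it reports them as unexpected keys. Moreover, a PEFT model is effectively split into the frozen base model and the adapter layers (the PEFT-specific components), and only the adapters (plus your classification head) change during training, so saving the full state_dict with torch.save is not necessary in the first place.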

To properly save and load a PEFT-tuned model, handle the base model and the adapter layers separately. Here's how you can do it:

Saving:

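  1. Save the adapters separately from the base model:
    from peft import get_peft_model_state_dict
    
    # The base model is frozen during LoRA training, so it does not need to be
    # saved: it can be reloaded from 'BAAI/bge-multilingual-gemma2' later.
    # (Its state_dict would also contain the 4-bit quantization entries such as
    # '.absmax' that a plain load_state_dict() reports as unexpected.)
    
    # Save only the adapter layers (PEFT part);
    # lora_model.save_pretrained(path) is PEFT's built-in equivalent
    torch.save(get_peft_model_state_dict(lora_model), 'adapter_model.pt')
    
    # Optionally, save the optimizer and other states
    torch.save({
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'optimizer_checkpoint.pt')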
    

Loading:

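  1. Load the base model and adapters separately:
    from peft import get_peft_model, prepare_model_for_kbit_training, set_peft_model_state_dict
    
    # Load the base model again from the Hub, with the same quantization
    # settings that were used during training
    model = AutoModel.from_pretrained('BAAI/bge-multilingual-gemma2',
                                      torch_dtype=torch.float16,
                                      quantization_config=quantization_config)
    model = prepare_model_for_kbit_training(model)  # only if training will continue
    
    # Re-create the LoRA layers, then load the trained adapter weights into them
    lora_model = get_peft_model(model, lora_config)
    set_peft_model_state_dict(lora_model, torch.load('adapter_model.pt'))
    
    # Optionally, load the optimizer and other states
    # (re-create the optimizer over the new model's parameters first)
    checkpoint = torch.load('optimizer_checkpoint.pt')
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])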
    

This method ensures you properly handle the PEFT model by keeping the base model and adapters separate, avoiding issues with unexpected keys in the state dictionary.
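If you would rather keep a single checkpoint file for the whole wrapper (as in the original torch.save attempt), one option is to save only the state_dict entries that belong to trainable parameters (LoRA weights plus the classification head) and load them with strict=False into a freshly rebuilt wrapper. This is a sketch on a hypothetical toy Wrapper, not the post's model; it assumes the rebuilt model uses the same base checkpoint, quantization settings and LoRA config, so that the frozen weights reported as missing are already in place.

```python
import io

import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for PreferencePredictionModel: frozen backbone + head."""

    def __init__(self, hidden: int = 8, num_classes: int = 2):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)
        self.classifier = nn.Linear(hidden, num_classes)
        self.backbone.requires_grad_(False)


def trainable_state_dict(model: nn.Module) -> dict:
    """Keep only the state_dict entries whose parameters are trainable."""
    names = {name for name, p in model.named_parameters() if p.requires_grad}
    return {k: v.detach().clone() for k, v in model.state_dict().items() if k in names}


trained = Wrapper()
buffer = io.BytesIO()  # stands in for a file path such as 'PreferencePredictionModel.pt'
torch.save({"trainable_state_dict": trainable_state_dict(trained)}, buffer)
buffer.seek(0)

rebuilt = Wrapper()
# Stands in for reloading the same pretrained base model from the Hub.
rebuilt.backbone.load_state_dict(trained.backbone.state_dict())
checkpoint = torch.load(buffer)
result = rebuilt.load_state_dict(checkpoint["trainable_state_dict"], strict=False)
```

Inspecting result.missing_keys and result.unexpected_keys after loading, instead of ignoring them, is what keeps strict=False from silently skipping weights: only frozen base weights should be missing, and nothing should be unexpected.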


Hi @Alanturner2, thanks for your quick response. I'm happy to see you are not shocked by my "architecture".

Regarding the save/load part, I was at first a bit confused about the optimizer checkpoint:

# Optionally, load the optimizer and other states
checkpoint = torch.load('optimizer_checkpoint.pt')
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

My original intention was to save the model to run inference in another notebook (on Kaggle), so I had no need to reuse the optimizer for further training.
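For inference only, the optimizer state is indeed unnecessary: the reloaded model just needs eval mode (which disables the head's dropout) and a no-grad forward pass. A minimal sketch, assuming batches shaped like the post's dataloader (dicts with input_ids and attention_mask); predict and ToyModel are hypothetical names, and ToyModel only mimics the forward signature of PreferencePredictionModel.

```python
import torch
from torch import nn


@torch.no_grad()
def predict(model: nn.Module, batches, device: str = "cpu") -> torch.Tensor:
    """Return predicted class indices; no optimizer or gradients are needed."""
    model.eval()  # disables dropout in the classification head
    preds = []
    for batch in batches:
        logits = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        preds.append(logits.argmax(dim=-1).cpu())
    return torch.cat(preds)


class ToyModel(nn.Module):
    """Hypothetical stand-in with the same forward signature as the wrapper."""

    def __init__(self, vocab: int = 10, num_classes: int = 2):
        super().__init__()
        self.emb = nn.Embedding(vocab, 4)
        self.head = nn.Linear(4, num_classes)

    def forward(self, input_ids, attention_mask):
        h = self.emb(input_ids) * attention_mask.unsqueeze(-1)
        return self.head(h.sum(dim=1))


toy = ToyModel()
batches = [{"input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]]),
            "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]])}]
preds = predict(toy, batches)
```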

However, I managed to use your code sample, and instead of the optimizer, I manually saved the classification head of my custom model:

#-------------------------------------------------------------------
def custom_save_model_chkpt(model, checkpointName, optimizer=None):
    # peft model
    model.gemma_model.save_pretrained(f'../Checkpoints/{checkpointName}/PEFT-bge-multilingual-gemma2', save_adapters=True, save_embedding_layers=True)
    
    # features and classifier
    torch.save({
        'epoch': 0,
        #'optimizer_state_dict': optimizer.state_dict(),
        #'feature_fc_state_dict': model.feature_fc.state_dict(),
        'classifier_state_dict': model.classifier.state_dict(),
        }, f'../Checkpoints/{checkpointName}/PreferencePredictionModel.pt')

#-------------------------------------------------------------------
def custom_load_model_chkpt(baseModelPath, checkpointName, quantization_config=None, optimizer=None):
    # load base
    baseModel = AutoModel.from_pretrained(
            baseModelPath,
            torch_dtype=torch.float16,
            quantization_config=quantization_config
            )

    baseModel = prepare_model_for_kbit_training(baseModel)
    
    # load peft from base
    loraModel_load = PeftModel.from_pretrained(
            baseModel,
            #torch_dtype=torch.float16,
            f'../Checkpoints/{checkpointName}/PEFT-bge-multilingual-gemma2',
            is_trainable=True
            )
    
    predictionModelLoaded = PreferencePredictionModel(loraModel_load, num_classes=2)
    
    checkpoint = torch.load(f'../Checkpoints/{checkpointName}/PreferencePredictionModel.pt')
    
    predictionModelLoaded.classifier.load_state_dict(checkpoint['classifier_state_dict'])
    
    return predictionModelLoaded

After a few tests, errors, and fixes, this runs like a charm. I haven't checked yet whether the weights are correct, but I will do it later and update this post if this solution turns out not to work properly.
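One way to run the weight check mentioned above, while the trained model is still in memory, is to compare every trainable tensor of the original model with the one in the reloaded model. A minimal sketch, with compare_trainable_weights as a hypothetical helper name and a small nn.Sequential standing in for the real wrapper; frozen parameters are skipped because they come from the same pretrained checkpoint.

```python
import copy

import torch
from torch import nn


def compare_trainable_weights(original: nn.Module, reloaded: nn.Module) -> list:
    """Return the names of trainable tensors that differ after a save/load round trip."""
    reloaded_params = dict(reloaded.named_parameters())
    mismatched = []
    for name, p in original.named_parameters():
        if not p.requires_grad:
            continue  # frozen base weights come from the same pretrained checkpoint
        other = reloaded_params.get(name)
        if other is None or not torch.equal(p.detach().cpu(), other.detach().cpu()):
            mismatched.append(name)
    return mismatched


original = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
original[0].requires_grad_(False)  # frozen "base" layer
reloaded = copy.deepcopy(original)  # stands in for the model rebuilt from the checkpoint
same = compare_trainable_weights(original, reloaded)

with torch.no_grad():
    reloaded[1].bias.add_(1.0)  # simulate a head that was not restored correctly
changed = compare_trainable_weights(original, reloaded)
```

With the real wrapper, the names reported would be the LoRA and classifier parameter names, so an empty list means both the adapters and the head were restored.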

Thank you for your help, love the community!

Cheers

