How to Fine-Tune the Phi-3-Vision Model with LoRA for Recognizing UI Elements in Images?

Hello Hugging Face Community,

I’m working on a project to fine-tune the Phi-3-Vision model to recognize UI elements such as buttons, posters, and icons in images, especially those related to streaming services. I have a few questions before diving into the implementation:

  1. Should I focus on fine-tuning only the language part of the model, or is it beneficial to fine-tune the vision part as well? (My current thinking on this is sketched right after this list.)
  2. How can I effectively use Low-Rank Adaptation (LoRA) to fine-tune this model on a local machine?
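
For question 1, my current thinking is to adapt only the language side and keep the image encoder frozen, since my dataset is fairly small. Below is a rough sketch of how I would pick which Linear layers get LoRA; the assumption that the vision tower lives under modules whose names contain "vision" (e.g. vision_embed_tokens) is just my guess about Phi-3-vision's internals, so please correct me if that filter is wrong:

import torch

def select_lora_targets(model):
    # Return the names of Linear modules to adapt with LoRA, skipping anything
    # under the vision tower (assumed here to have "vision" in its module path).
    targets = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "vision" not in name:
            targets.append(name)
    return targets

I would then restrict the LoRA-injection loop in the snippet further down to just these module names.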

I found an example code snippet that outlines the steps for this process, but I haven’t implemented it yet. Before proceeding, I want to ensure that this approach is sound and to get any additional recommendations or insights from the community.

Here is the example code snippet for reference:

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, Trainer, TrainingArguments
import loralib as lora
from PIL import Image

DATASET_DIR = "path/to/dataset"  # directory that holds the training images

# Load the model and processor (Phi-3-vision ships custom modeling code, so trust_remote_code is required)
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)

# Swap each nn.Linear for a loralib Linear (r=16), copying the pretrained weights over.
# Note: setattr on the root model cannot replace nested modules, so walk to the parent module first.
for name, module in list(model.named_modules()):
    if isinstance(module, torch.nn.Linear):
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        lora_linear = lora.Linear(module.in_features, module.out_features, r=16, bias=module.bias is not None)
        lora_linear.weight.data.copy_(module.weight.data)
        if module.bias is not None:
            lora_linear.bias.data.copy_(module.bias.data)
        setattr(parent, child_name, lora_linear)

# Mark LoRA parameters as trainable
lora.mark_only_lora_as_trainable(model)

# Custom data collator that builds Phi-3-vision chat prompts and pixel inputs per batch
class CustomDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            question = "sample question text"
            answer = "sample answer text"
            INST_PREFIX = 'sample instruction prefix'
            # Phi-3-vision expects an <|image_1|> placeholder in the prompt for the attached image
            messages = [
                {"role": "user", "content": f"<|image_1|>\n{INST_PREFIX} {question}"},
                {"role": "assistant", "content": answer}
            ]
            # No generation prompt here: the assistant answer is already part of the training text
            text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            texts.append(text)
            image_file = Image.open(f"{DATASET_DIR}/{example['image']}")
            images.append(image_file)
        batch = self.processor(texts, images, return_tensors="pt")
        labels = batch["input_ids"].clone()
        # Ignore padding positions in the loss
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    remove_unused_columns=False,  # keep raw dataset columns (e.g. 'image') visible to the custom collator
)

# Initialize Trainer (pass Dataset objects, not DataLoaders, and wire in the custom collator)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # your training Dataset with an 'image' field per example
    eval_dataset=eval_dataset,    # your validation Dataset
    data_collator=CustomDataCollator(processor),
)

# Train the model
trainer.train()

# Save the model
torch.save(lora.lora_state_dict(model), 'lora_checkpoint.pt')
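
For reusing the adapter later, my plan is the usual loralib pattern: rebuild the base model, re-apply the same LoRA injection as above, and then load only the saved LoRA weights on top of the pretrained ones. A minimal sketch:

# Load the LoRA-only checkpoint on top of the pretrained weights
# (the model must already have the lora.Linear layers injected, as above)
lora_weights = torch.load('lora_checkpoint.pt')
model.load_state_dict(lora_weights, strict=False)
model.eval()

Does that look right, or is there a better way to package the adapter for inference?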

I appreciate any feedback or suggestions on this approach. Thank you in advance for your help!

Best regards