Cannot export TFLite using Optimum for a fine-tuned Gemma 3 model (task: question answering)

Hello, I was able to overcome this problem by making some changes to the Colab notebook code prepared by Google. You can convert the model directly to TFLite and then, if desired, to the .task format. Instead of fine-tuning from scratch, I started from my already fine-tuned model (merged into the base weights).
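
One note before the code: the snapshot_download step below expects the Hub repo to contain a full merged checkpoint, not just LoRA adapter weights. If your fine-tune produced PEFT adapters, a minimal sketch like this merges and uploads them first (assuming a peft-based fine-tune; "username/adapter_repo" and "username/model_repo" are placeholders matching the repo used below):

# Hypothetical merge step for a LoRA/PEFT fine-tune -- skip it if your repo
# already holds full merged weights. Repo names are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
model = PeftModel.from_pretrained(base, "username/adapter_repo")
merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.push_to_hub("username/model_repo")  # the repo snapshot_download pulls below

# The .task bundler at the end needs tokenizer.model, so upload the tokenizer too.
AutoTokenizer.from_pretrained("google/gemma-3-1b-it").push_to_hub("username/model_repo")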

import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

!pip3 install --upgrade -q -U bitsandbytes
!pip3 install --upgrade -q -U peft
!pip3 install --upgrade -q -U trl
!pip3 install --upgrade -q -U accelerate
!pip3 install --upgrade -q -U datasets
!pip3 install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

!pip install git+https://github.com/google-ai-edge/ai-edge-torch
!pip install ai-edge-litert
!pip install mediapipe


!pip install huggingface_hub
from huggingface_hub import snapshot_download
import shutil

# 🔧 Configure the settings
model_name = "username/model_repo"  # ← fine-tuned model
local_dir = "/content/merged_model"

# 💾 Download the checkpoint from Hugging Face
snapshot_download(
    repo_id=model_name,
    local_dir=local_dir,
    local_dir_use_symlinks=False  # Don't bother with symlinks, just copy the files directly
)

print(f"Model downloaded: {local_dir}")

!git clone https://github.com/google-ai-edge/ai-edge-torch.git

!pip uninstall -y numpy
!pip uninstall -y torch torchvision torchaudio
!pip uninstall -y ai-edge-torch ai-edge-litert ai-edge-quantizer torch-xla2 safetensors
!pip install numpy
!pip install torch torchvision torchaudio
!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt

!pip install --upgrade numpy
!pip install --upgrade --force-reinstall ai-edge-torch
!pip install --upgrade --force-reinstall ai-edge-litert
!pip install --upgrade --force-reinstall ai-edge-quantizer
!pip install --upgrade --force-reinstall torch-xla2
!pip install --upgrade --force-reinstall safetensors
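
After this round of uninstalls and force-reinstalls, restart the Colab runtime so the reinstalled numpy/torch are actually picked up, then sanity-check the environment (plain version prints; I'm assuming ai_edge_torch exposes __version__, which recent releases do):

# Run after restarting the runtime: confirm the reinstalled packages import
# cleanly and print their versions.
import numpy, torch
import ai_edge_torch

print("numpy:", numpy.__version__)
print("torch:", torch.__version__)
print("ai-edge-torch:", ai_edge_torch.__version__)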


from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
from ai_edge_torch.generative.layers.experimental import kv_cache
import torch


def _create_mask(mask_len, kv_cache_max_len):
    # Build a causal attention mask of shape (1, 1, mask_len, kv_cache_max_len):
    # positions above the diagonal are -inf (masked out), the rest are 0.
    mask = torch.full((mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32)
    return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)

def _create_export_config(prefill_seq_lens: list[int], kv_cache_max_len: int) -> ExportConfig:
    # One prefill mask per supported prefill sequence length, plus a
    # single-row mask for the token-by-token decode step.
    export_config = ExportConfig()
    export_config.prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
    decode_mask = torch.full((1, kv_cache_max_len), float('-inf'), dtype=torch.float32)
    export_config.decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    # Gemma 3 uses the transposed KV-cache layout from the experimental layers.
    export_config.kvcache_cls = kv_cache.KVCacheTransposed
    return export_config


with torch.inference_mode(True):
    checkpoint_path = "/content/merged_model"
    # kv_cache_max_len must match the value passed to _create_export_config below.
    pytorch_model = gemma3.build_model_1b(
        checkpoint_path, kv_cache_max_len=2048
    )

    export_config = _create_export_config([1024], 2048)

    converter.convert_to_tflite(
        pytorch_model,
        output_path="/content/",
        output_name_prefix="gemma3_1b_finetune",
        prefill_seq_len=[1024],  # must match the prefill lengths in the export config
        quantize=True,
        lora_ranks=None,
        export_config=export_config
    )
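
convert_to_tflite writes the model into output_path with a name derived from the prefix, quantization, and KV-cache length; with the settings above that is gemma3_1b_finetune_q8_ekv2048.tflite. As a quick smoke test (a sketch, assuming ai-edge-litert's TFLite-compatible Interpreter), you can confirm the file loads and exposes the prefill/decode signatures:

# Smoke test: open the exported model and list its signatures.
from ai_edge_litert.interpreter import Interpreter

interp = Interpreter(model_path="/content/gemma3_1b_finetune_q8_ekv2048.tflite")
print(interp.get_signature_list())  # expect prefill and decode entries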


from mediapipe.tasks.python.genai.bundler import llm_bundler

def build_task_bundle():
    config = llm_bundler.BundleConfig(
        tflite_model="/content/gemma3_1b_finetune_q8_ekv2048.tflite",
        tokenizer_model="/content/merged_model/tokenizer.model",
        start_token="<bos>",
        stop_tokens=["<eos>", "<end_of_turn>"],
        output_filename="/content/gemma3-1b-it.task",
        enable_bytes_to_unicode_mapping=False,
        prompt_prefix="<start_of_turn>user\n",
        prompt_suffix="<end_of_turn>\n<start_of_turn>model\n",
    )
    llm_bundler.create_bundle(config)

build_task_bundle()
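
If create_bundle finishes without errors, the .task bundle is ready for the MediaPipe LLM Inference API on device. A last trivial check that the file was actually written:

import os

task_path = "/content/gemma3-1b-it.task"
print(f"{task_path}: {os.path.getsize(task_path) / 1e6:.1f} MB")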



