I want to fine-tune a model that takes a string as input and generates the values of specific parameters as output.
E.g.:
Prompt: find all high priority tickets assigned to John
Output: {"assignee": "John", "priority": "High"}
Combining a small LLM with constrained (structured) decoding makes this easy to implement, and it works without any fine-tuning. If you need higher throughput, fine-tuning a small encoder (embedding) model is another option; see the sketch after the code example. Here is a working example that constrains a small instruct model to your schema with lm-format-enforcer:
# pip install -q -U transformers pydantic lm-format-enforcer accelerate
import json, torch
from typing import Literal
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
# Pydantic v2 schema
class Ticket(BaseModel):
    assignee: str
    priority: Literal["Low", "Medium", "High"]
schema = Ticket.model_json_schema()
# Instruct model
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
# Constrained decoding: prefix_fn masks the logits at each step so only schema-valid tokens can be generated
parser = JsonSchemaParser(schema)
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tok, parser)
@torch.inference_mode()
def infer(message):
    # Build the chat prompt as TEXT, then tokenize it to get a MAPPING
    messages = [
        {"role": "system", "content": "Return only valid minified JSON matching the schema."},
        {"role": "user", "content": message},
    ]
    chat_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)  # string
    enc = tok(chat_text, return_tensors="pt").to(model.device)  # mapping with input_ids + attention_mask
    out = model.generate(
        **enc,
        max_new_tokens=64,
        do_sample=False,
        num_beams=1,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        prefix_allowed_tokens_fn=prefix_fn,
    )
    gen_ids = out[:, enc["input_ids"].shape[-1]:]  # keep only the newly generated tokens
    json_str = tok.batch_decode(gen_ids, skip_special_tokens=True)[0]
    return json.loads(json_str)
print(infer("find all high priority tickets assigned to John")) # {'priority': 'High', 'assignee': 'John'}
print(infer("find Low priority tickets assigned to Smith")) # {'priority': 'Low', 'assignee': 'Smith'}