Model Selection to Convert a Prompt to a JSON Object

I want to fine-tune a model that takes a string as input and generates values for specific parameters as output.

Example:

Prompt: find all high priority tickets assigned to John

Output: {"assignee": "John", "priority": "High"}


Combining a smaller LLM with constrained structured output makes the implementation easy and needs no fine-tuning. If you need higher speed, fine-tuning an embedding model might be a good approach (a rough sketch of that route is at the end of this post).

# pip install -q -U transformers pydantic lm-format-enforcer accelerate
import json, torch
from typing import Literal
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

# Pydantic v2 schema
class Ticket(BaseModel):
    assignee: str
    priority: Literal["Low", "Medium", "High"]

schema = Ticket.model_json_schema()
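
# For reference, the generated schema looks roughly like (abbreviated):
# {"properties": {"assignee": {"type": "string", ...},
#                 "priority": {"enum": ["Low", "Medium", "High"], ...}},
#  "required": ["assignee", "priority"], "type": "object", ...}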

# Instruct model
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Constrained decoding: at each step, lm-format-enforcer only permits tokens
# that can still complete JSON matching the schema
parser = JsonSchemaParser(schema)
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tok, parser)

@torch.inference_mode()
def infer(message: str) -> dict:
    # Build the chat prompt as text, then tokenize to get input_ids + attention_mask
    messages = [
        {"role": "system", "content": "Return only valid minified JSON matching the schema."},
        {"role": "user", "content": message},
    ]
    chat_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)  # string
    enc = tok(chat_text, return_tensors="pt").to(model.device)  # dict with input_ids + attention_mask

    out = model.generate(
        **enc,
        max_new_tokens=64,
        do_sample=False,
        num_beams=1,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        prefix_allowed_tokens_fn=prefix_fn,
    )
    gen_ids = out[:, enc["input_ids"].shape[-1]:]
    json_str = tok.batch_decode(gen_ids, skip_special_tokens=True)[0]
    json_dict = json.loads(json_str)
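    # Optional: round-trip through the Pydantic model (model_validate/model_dump
    # are standard Pydantic v2); constrained decoding should already guarantee
    # schema-valid JSON, so this is just a safety net
    json_dict = Ticket.model_validate(json_dict).model_dump()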
    return json_dict

print(infer("find all high priority tickets assigned to John")) # {'priority': 'High', 'assignee': 'John'}
print(infer("find Low priority tickets assigned to Smith")) # {'priority': 'Low', 'assignee': 'Smith'}