import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor
device = "mps" if torch.backends.mps.is_available() else "cpu"
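# (MPS is PyTorch's Metal backend for Apple-silicon GPUs; we fall back to CPU elsewhere.)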
repo_id = "mistralai/Voxtral-Mini-3B-2507"
audio_url = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3"
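# A local file path should also work for `audio` (assumption: the processor reuses
# transformers' standard path/URL audio loading), e.g. audio="clip.mp3".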
processor = AutoProcessor.from_pretrained(repo_id)
# ⚠️ Prefer fp16 on MPS (bf16 support there is incomplete); force eager attention
# to sidestep known MPS SDPA quirks.
model = VoxtralForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch.float16 if device == "mps" else torch.float32,
    attn_implementation="eager",  # helps avoid MPS SDPA quirks
    device_map={"": device},      # single-device map; no auto-sharding on MPS
)
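# Ballpark memory budget (rough estimate, not from the model card): fp16 weights
# for Mini-3B (language model plus audio encoder) land on the order of 9-10 GB
# of unified memory, before activations and the KV cache.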
# Build the transcription request
inputs = processor.apply_transcription_request(
    language="en", audio=audio_url, model_id=repo_id
)
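# The request comes back as a BatchFeature holding the tokenized prompt
# (input_ids) plus the precomputed audio features the encoder consumes.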
# Move to device and cast only floating tensors to fp16 on MPS
inputs = inputs.to(device) # move first
for k, v in list(inputs.items()):
    if torch.is_tensor(v) and torch.is_floating_point(v) and device == "mps":
        inputs[k] = v.to(dtype=torch.float16)
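# (Equivalently, inputs.to(device, dtype=torch.float16) casts only the floating
# tensors and leaves input_ids as integers; the explicit loop just makes that visible.)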
# Greedy is fine for transcription; raise the budget for a ~5 min clip
outputs = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
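# generate() on a decoder-only model returns the prompt followed by the new
# tokens, so slice off the prompt length before decoding.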
decoded = processor.batch_decode(
    outputs[:, inputs.input_ids.shape[1]:],
    skip_special_tokens=True,
)
print("\nGenerated responses:\n" + "="*80)
for d in decoded:
    print(d)
print("="*80)