Getting started with Voxtral for ASR transcription

I am trying to run the default Voxtral transcription example (ASR on the Obama speech sample). The script only prints:

Generated responses:

This

How can this be changed so that the real/full transcript is returned, not just the first word?

import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor, infer_device

device = infer_device()
repo_id = "mistralai/Voxtral-Mini-3B-2507"

processor = AutoProcessor.from_pretrained(repo_id)
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)

inputs = processor.apply_transcription_request(language="en", audio="https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3", model_id=repo_id)
inputs = inputs.to(device, dtype=torch.bfloat16)

outputs = model.generate(**inputs, max_new_tokens=500)
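# Decode only the tokens generated after the prompt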
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

print("\nGenerated responses:")
print("=" * 80)
for decoded_output in decoded_outputs:
    print(decoded_output)
    print("=" * 80)



I think this is a bfloat16 mixup on MPS; bf16 isn't handled well there, so use fp16 instead:

import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor

device = "mps" if torch.backends.mps.is_available() else "cpu"
repo_id = "mistralai/Voxtral-Mini-3B-2507"
audio_url = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3"

processor = AutoProcessor.from_pretrained(repo_id)

# ⚠️ Use fp16 on MPS (avoid bf16). Also force eager attention on MPS for correctness.
model = VoxtralForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch.float16 if device == "mps" else torch.float32,
    attn_implementation="eager",          # helps avoid MPS SDPA quirks
    device_map={"": device},              # single-device map; no auto-sharding on MPS
)

# Build the transcription request
inputs = processor.apply_transcription_request(
    language="en", audio=audio_url, model_id=repo_id
)

# Move to device and cast only floating tensors to fp16 on MPS
inputs = inputs.to(device)               # move first
for k, v in list(inputs.items()):
    if torch.is_tensor(v) and torch.is_floating_point(v) and device == "mps":
        inputs[k] = v.to(dtype=torch.float16)

# Greedy is fine for transcription; raise the budget for a ~5 min clip
outputs = model.generate(**inputs, max_new_tokens=2048, do_sample=False)

decoded = processor.batch_decode(
    outputs[:, inputs.input_ids.shape[1]:],
    skip_special_tokens=True
)

print("\nGenerated responses:\n" + "="*80)
for d in decoded:
    print(d)
    print("="*80)

This fixes things for me.
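If you want to confirm that the dtype/device setup is what fixed it (rather than just the larger token budget), a quick sanity check with the same variables:

# Sanity check: model should sit on MPS in fp16, and generation should stay under the 2048-token budget
p = next(model.parameters())
print("model device/dtype:", p.device, p.dtype)
print("new tokens generated:", outputs.shape[1] - inputs.input_ids.shape[1])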

