Issue while quantizing Gemma 4 E2B/E4B - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

Hello.
I’m trying to run inference with Gemma any-to-any models and quantize them. I ran into lots of different errors caused by package versions and installation. I tried many GitHub solutions, but if I downgrade transformers, AutoProcessor can no longer be found.

!pip uninstall -y torch torchvision torchaudio torchcodec -q
!pip install -U \
  "torch==2.11.0" \
  "torchvision==0.26.0" \
  "torchaudio==2.11.0" \
    torchcodec \
  --index-url https://download.pytorch.org/whl/cu128 -q

!pip install -U transformers  librosa accelerate bitsandbytes -q
!pip install soundfile -q
!pip install "datasets==2.21.0" -q
import os
import io

import numpy as np
import pandas as pd

import torch
import torchaudio

import soundfile as sf

import datasets
from datasets import load_dataset

from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import BitsAndBytesConfig

from transformers import logging
logging.set_verbosity(logging.CRITICAL)

import logging
logging.getLogger().setLevel(logging.ERROR)

import warnings
warnings.filterwarnings("ignore")
irish_male_dataset = load_dataset("ylacombe/english_dialects", "irish_male", split="train")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

MODEL_ID = "google/gemma-4-E2B-it"

processor = AutoProcessor.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
).eval()

def transcribe_batch(batch):
    """Transcribe each audio clip in a batch and attach the text as a new column.

    Args:
        batch: A batched ``datasets`` mapping whose ``"audio"`` entry is a list
            of undecoded audio dicts carrying ``"bytes"`` and/or ``"path"``.

    Returns:
        The same ``batch`` with ``"transcribed_text"`` set to one transcription
        per input clip.

    Relies on the module-level ``processor`` and ``model`` objects.
    """
    transcriptions = []

    for i, f in enumerate(batch["audio"]):
        # Prefer the embedded bytes; fall back to the file on disk only when
        # the entry carries a path but no bytes. (The previous if/else had two
        # identical branches, so the fallback never actually happened.)
        raw_bytes = f.get("bytes")
        if raw_bytes is not None:
            wav, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
        else:
            wav, sr = sf.read(f["path"], dtype="float32")

        # Collapse a 2-D (frames, channels) array to 1-D by averaging axis 1.
        if wav.ndim > 1:
            wav = wav.mean(axis=1)

        if sr != 16000:
            wav = torchaudio.functional.resample(
                torch.tensor(wav, dtype=torch.float32),
                orig_freq=sr,
                new_freq=16000
            ).numpy()

        tmp_path = f"/tmp/tmp_audio_{i}.wav"
        try:
            sf.write(tmp_path, wav, 16000)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "audio", "audio": tmp_path},
                        {"type": "text", "text": (
                            "Transcribe the following speech segment in English into English text. Don't generate additonal text."
                        )},
                    ]
                }
            ]

            inputs = processor.apply_chat_template(
                messages,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(model.device)
            # Number of prompt tokens, used to slice the prompt off the output.
            input_len = inputs["input_ids"].shape[-1]

            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=64)

            response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
            transcriptions.append(response)
        finally:
            # Remove the temporary WAV even if processing or generation fails,
            # so a long run does not leave files behind in /tmp.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        torch.cuda.empty_cache()

    batch["transcribed_text"] = transcriptions
    print(transcriptions)
    return batch

irish_male_dataset = irish_male_dataset.cast_column("audio", datasets.features.Audio(decode=False))

processed_dataset = irish_male_dataset.map(
    transcribe_batch,
    batched=True,
    batch_size=32,
    remove_columns=["audio"],
    load_from_cache_file=False,
    desc="Transcribing Audios by Irish Men"
)

df = pd.DataFrame(processed_dataset)

df.to_csv("irish_male_transcribed_gemma-4-e2b-it.csv", index=False)

Here is the full error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_57/390005493.py in <cell line: 0>()
----> 1 processed_dataset = irish_male_dataset.map(
      2     transcribe_batch,
      3     batched=True,
      4     batch_size=32,
      5     remove_columns=["audio"],

/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    600             self: "Dataset" = kwargs.pop("self")
    601         # apply actual function
--> 602         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    603         datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    604         for dataset in datasets:

/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    565         }
    566         # apply actual function
--> 567         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    568         datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    569         # re-apply format to the output

/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3165                     desc=desc or "Map",
   3166                 ) as pbar:
-> 3167                     for rank, done, content in Dataset._map_single(**dataset_kwargs):
   3168                         if done:
   3169                             shards_done += 1

/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in _map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
   3556                         )  # Something simpler?
   3557                         try:
-> 3558                             batch = apply_function_on_filtered_inputs(
   3559                                 batch,
   3560                                 indices,

/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
   3425             if with_rank:
   3426                 additional_args += (rank,)
-> 3427             processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
   3428             if isinstance(processed_inputs, LazyDict):
   3429                 processed_inputs = {

/tmp/ipykernel_57/2288985515.py in transcribe_batch(batch)
     44 
     45         with torch.no_grad():
---> 46             outputs = model.generate(**inputs, max_new_tokens=64)
     47 
     48         response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)

/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    122         # pyrefly: ignore [bad-context-manager]
    123         with ctx_factory():
--> 124             return func(*args, **kwargs)
    125 
    126     return decorate_context

/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, custom_generate, **kwargs)
   2558 
   2559         # 9. Call generation mode
-> 2560         result = decoding_method(
   2561             self,
   2562             input_ids,

/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   2751 
   2752         prefill_consumed = False
-> 2753         outputs = self._prefill(
   2754             input_ids,
   2755             generation_config,

/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _prefill(self, input_ids, generation_config, model_kwargs, is_first_iteration)
   3797                 **model_kwargs,
   3798             )
-> 3799             return self(**model_inputs, return_dict=True)
   3800 
   3801         # Chunked prefill (for very large contexts)

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1777             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778         else:
-> 1779             return self._call_impl(*args, **kwargs)
   1780 
   1781     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1788                 or _global_backward_pre_hooks or _global_backward_hooks
   1789                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790             return forward_call(*args, **kwargs)
   1791 
   1792         result = None

/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
    898         if return_dict_passed is not None:
    899             return_dict = return_dict_passed
--> 900         output = func(self, *args, **kwargs)
    901         if not return_dict and not isinstance(output, tuple):
    902             output = output.to_tuple()

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, image_position_ids, video_position_ids, past_key_values, mm_token_type_ids, inputs_embeds, labels, use_cache, logits_to_keep, **kwargs)
   2534             Passed through to the vision encoder for positional embedding computation.
   2535         """
-> 2536         outputs = self.model(
   2537             input_ids=input_ids,
   2538             pixel_values=pixel_values,

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1777             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778         else:
-> 1779             return self._call_impl(*args, **kwargs)
   1780 
   1781     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1788                 or _global_backward_pre_hooks or _global_backward_hooks
   1789                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790             return forward_call(*args, **kwargs)
   1791 
   1792         result = None

/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
    974                     output = func(self, *args, **kwargs)
    975             else:
--> 976                 output = func(self, *args, **kwargs)
    977         # Restore original config value
    978         finally:

/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
    898         if return_dict_passed is not None:
    899             return_dict = return_dict_passed
--> 900         output = func(self, *args, **kwargs)
    901         if not return_dict and not isinstance(output, tuple):
    902             output = output.to_tuple()

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, past_key_values, mm_token_type_ids, inputs_embeds, use_cache, image_position_ids, video_position_ids, **kwargs)
   2343         # Merge text and audio
   2344         if input_features is not None and input_features_mask is not None:
-> 2345             audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
   2346             audio_features = audio_output.pooler_output
   2347             audio_mask_from_encoder = audio_output.attention_mask  # True = valid

/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
    898         if return_dict_passed is not None:
    899             return_dict = return_dict_passed
--> 900         output = func(self, *args, **kwargs)
    901         if not return_dict and not isinstance(output, tuple):
    902             output = output.to_tuple()

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in get_audio_features(self, input_features, input_features_mask, **kwargs)
   2438             )
   2439 
-> 2440         audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
   2441         audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
   2442 

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1777             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778         else:
-> 1779             return self._call_impl(*args, **kwargs)
   1780 
   1781     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1788                 or _global_backward_pre_hooks or _global_backward_hooks
   1789                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790             return forward_call(*args, **kwargs)
   1791 
   1792         result = None

/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
    974                     output = func(self, *args, **kwargs)
    975             else:
--> 976                 output = func(self, *args, **kwargs)
    977         # Restore original config value
    978         finally:

/usr/local/lib/python3.12/dist-packages/transformers/utils/output_capturing.py in wrapper(self, *args, **kwargs)
    246             # Run the forward
    247             try:
--> 248                 outputs = func(self, *args, **kwargs)
    249             # Reset the states
    250             finally:

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_features, attention_mask, **kwargs)
   1965 
   1966         for encoder_layer in self.layers[: self.config.num_hidden_layers]:
-> 1967             hidden_states = encoder_layer(
   1968                 hidden_states,
   1969                 attention_mask=attention_mask,

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1777             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778         else:
-> 1779             return self._call_impl(*args, **kwargs)
   1780 
   1781     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1788                 or _global_backward_pre_hooks or _global_backward_hooks
   1789                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790             return forward_call(*args, **kwargs)
   1791 
   1792         result = None

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, hidden_states, attention_mask, position_embeddings, **kwargs)
    544         gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
    545 
--> 546         hidden_states = self.feed_forward1(hidden_states)
    547         residual = hidden_states
    548 

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1777             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778         else:
-> 1779             return self._call_impl(*args, **kwargs)
   1780 
   1781     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1788                 or _global_backward_pre_hooks or _global_backward_hooks
   1789                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790             return forward_call(*args, **kwargs)
   1791 
   1792         result = None

/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, hidden_states)
    425     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    426         # This is needed to avoid any underflow/overflow issues when clipping
--> 427         gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
    428 
    429         residual = hidden_states

TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

It appears to be a known, unresolved issue, but it seems possible to work around it by excluding the modules that cause the error from quantization.


Gemma 4 E2B/E4B + bitsandbytes 4-bit: torch.finfo() error on audio input

TL;DR

This looks like a Gemma 4 audio-module quantization issue, not a datasets.map() problem and not mainly an audio-decoding problem.

Your stack trace enters the Gemma 4 audio path:

model.generate(...)
  -> Gemma4 forward
    -> get_audio_features(...)
      -> audio_tower(...)
        -> Gemma4AudioFeedForward.forward(...)
          -> torch.finfo(self.ffw_layer_1.linear.weight.dtype).max

The most likely cause is:

  1. google/gemma-4-E2B-it / google/gemma-4-E4B-it have native audio support.
  2. Your prompt contains an audio input, so the model uses its audio_tower.
  3. bitsandbytes 4-bit quantization converts eligible linear layers, including audio layers unless told not to.
  4. A Linear4bit layer can expose packed quantized storage with dtype torch.uint8.
  5. Gemma 4 audio code calls torch.finfo(weight.dtype).
  6. torch.finfo() only works for floating-point dtypes, not torch.uint8.

So the practical fix is:

  • use AutoModelForMultimodalLM for Gemma 4 audio;
  • keep the main LLM body quantized in 4-bit;
  • skip quantization for audio_tower and embed_audio;
  • also manually skip lm_head, because of a separate skip-list pitfall in affected transformers versions.

Recommended loading pattern:

import torch
from transformers import AutoProcessor, AutoModelForMultimodalLM, BitsAndBytesConfig

MODEL_ID = "google/gemma-4-E2B-it"
# MODEL_ID = "google/gemma-4-E4B-it"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=[
        "lm_head",
        "audio_tower",
        "embed_audio",
        "model.audio_tower",
        "model.embed_audio",
    ],
)

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=bnb_config,
).eval()

If your installed transformers version rejects dtype=..., use the older argument name:

model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
).eval()

Why this is not mainly a datasets.map() issue

The traceback starts inside:

processed_dataset = irish_male_dataset.map(...)

but the actual failure happens inside:

outputs = model.generate(**inputs, max_new_tokens=64)

datasets.map() is only the outer loop. It calls your transcribe_batch() function, and your function calls model.generate().

A more accurate summary of the traceback is:

datasets.map
  -> transcribe_batch
    -> model.generate
      -> Gemma 4 multimodal forward
        -> audio_tower
          -> Gemma4AudioFeedForward
            -> torch.finfo(torch.uint8)
              -> TypeError

So changing only these is unlikely to fix the root cause:

batch_size=32
remove_columns=["audio"]
load_from_cache_file=False
cast_column("audio", datasets.features.Audio(decode=False))

Those can affect memory, speed, caching, or audio decoding behavior, but they do not explain why model code is calling torch.finfo() on an integer dtype.


Why your code triggers the audio path

Gemma 4 E2B/E4B are multimodal models. Google’s Gemma overview says Gemma 4 supports text, audio, and image input, and Google DeepMind’s Gemma 4 page describes E2B/E4B as having audio and vision support for edge processing.

Your prompt contains:

{"type": "audio", "audio": tmp_path}

That causes the processor to produce audio inputs such as input_features and input_features_mask.

Your traceback confirms this:

if input_features is not None and input_features_mask is not None:
    audio_output = self.get_audio_features(...)

So the model is doing:

audio file
  -> processor
    -> input_features
      -> Gemma 4 audio_tower
        -> embed_audio
          -> language model
            -> generated transcript

For this path, the Transformers Gemma 4 docs use AutoModelForMultimodalLM, not AutoModelForCausalLM.

So I would replace:

from transformers import AutoProcessor, AutoModelForCausalLM

with:

from transformers import AutoProcessor, AutoModelForMultimodalLM

and replace:

model = AutoModelForCausalLM.from_pretrained(...)

with:

model = AutoModelForMultimodalLM.from_pretrained(...)

The key dtype trap: compute dtype is not storage dtype

Your config says:

bnb_4bit_compute_dtype=torch.bfloat16

That is useful, but it does not guarantee that every quantized module’s weight.dtype is torch.bfloat16.

bitsandbytes 4-bit layers can use packed quantized storage. The Linear4bit docs show quant_storage=torch.uint8 as the default.

So this can be true:

bnb_4bit_compute_dtype == torch.bfloat16

while this is also true inside a converted 4-bit module:

module.weight.dtype == torch.uint8

That is exactly why the Gemma 4 audio code can fail here:

torch.finfo(self.ffw_layer_1.linear.weight.dtype).max

because self.ffw_layer_1.linear.weight.dtype may be torch.uint8.


Why torch.finfo(torch.uint8) fails

PyTorch separates floating-point dtype metadata from integer dtype metadata:

  • torch.finfo(...) is for floating-point dtypes such as torch.float32, torch.float16, and torch.bfloat16.
  • torch.iinfo(...) is for integer dtypes such as torch.uint8, torch.int8, torch.int16, torch.int32, and torch.int64.

Relevant docs: the "Type Info" page of the PyTorch documentation, which covers both torch.finfo and torch.iinfo.

The error message is therefore technically correct:

TypeError: torch.finfo() requires a floating point input type.
Use torch.iinfo to handle 'torch.finfo'

However, in this model path, simply replacing finfo with iinfo is probably not the right semantic fix.

Why?

Because torch.uint8 here is likely packed quantized storage, not a meaningful floating-point activation dtype. The audio code wants a floating-point clamp/limit. torch.iinfo(torch.uint8).max == 255, but 255 is an integer storage bound, not a meaningful floating-point clipping bound for audio hidden states.

So the better fix is:

Do not quantize that audio module in the first place.

Closest matching upstream issue

There is a closely matching upstream transformers issue:

The issue reports this same core failure mode:

  • model: google/gemma-4-e2b-it
  • quantization: BitsAndBytesConfig(load_in_4bit=True)
  • failure: torch.finfo() receives torch.uint8
  • affected modules: Gemma4AudioFeedForward, Gemma4AudioLightConv1d, Gemma4AudioLayer

The issue describes the root cause as audio modules performing clipping via torch.finfo(weight.dtype), while 4-bit quantization can make weight.dtype equal to torch.uint8.

That is a very close match to your traceback.


Relevant PR: skip audio modules during conversion

There is also a related transformers PR:

The direction of that PR is important: it does not primarily try to patch every torch.finfo(...) call. It moves toward excluding audio modules from quantization conversion.

That is why the practical workaround is to skip:

"audio_tower"
"embed_audio"
"model.audio_tower"
"model.embed_audio"

This is usually a good trade-off because the main memory saving comes from quantizing the LLM body, while the audio tower is much smaller and more sensitive for ASR quality.


Separate pitfall: manually include lm_head

There is a second issue that matters once you add llm_int8_skip_modules:

In affected versions, if you pass any custom skip list, the default “do not quantize” list may be cleared. That default list normally protects lm_head.

So this may fix the audio tower but introduce another problem:

llm_int8_skip_modules=["model.audio_tower"]

Safer:

llm_int8_skip_modules=[
    "lm_head",
    "audio_tower",
    "embed_audio",
    "model.audio_tower",
    "model.embed_audio",
]

That is why I recommend including lm_head explicitly.


Recommended corrected model-loading code

import torch
from transformers import AutoProcessor, AutoModelForMultimodalLM, BitsAndBytesConfig

MODEL_ID = "google/gemma-4-E2B-it"
# MODEL_ID = "google/gemma-4-E4B-it"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=[
        "lm_head",
        "audio_tower",
        "embed_audio",
        "model.audio_tower",
        "model.embed_audio",
    ],
)

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=bnb_config,
).eval()

Fallback if dtype= is not accepted:

model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
).eval()

Verify that the workaround actually worked

Do not assume the skip list worked. Check the loaded module tree.

def inspect_audio_quantization(model):
    """Print a report of audio-related submodules that still look quantized.

    A submodule is considered audio-related when its name (as yielded by
    ``model.named_modules()``) contains "audio", case-insensitively. It is
    flagged when its class name contains "4bit" or "8bit", or when its
    ``weight`` has a dtype that is not floating point.

    Prints either a warning followed by one line per flagged module, or a
    single OK line. Returns ``None``.
    """
    flagged = []

    for name, module in model.named_modules():
        # "embed_audio" itself contains "audio", so one substring test selects
        # exactly the same modules as testing for both.
        if "audio" not in name.lower():
            continue

        cls_name = type(module).__name__
        weight_dtype = getattr(getattr(module, "weight", None), "dtype", None)

        lowered_cls = cls_name.lower()
        looks_quantized = "4bit" in lowered_cls or "8bit" in lowered_cls

        # Short-circuit keeps the dtype check from running for classes already
        # flagged by name.
        if looks_quantized or (
            weight_dtype is not None and not weight_dtype.is_floating_point
        ):
            flagged.append((name, cls_name, weight_dtype))

    if not flagged:
        print("OK: no obvious quantized audio-related modules found.")
        return

    print("WARNING: audio-related modules may still be quantized:")
    for entry_name, entry_cls, entry_dtype in flagged:
        print(f"{entry_name}: {entry_cls}, weight dtype={entry_dtype}")

inspect_audio_quantization(model)

Bad output would look like:

model.audio_tower.layers.0.feed_forward1.ffw_layer_1.linear: Linear4bit, weight dtype=torch.uint8

If you see that, the audio tower is still quantized.

Print actual module names:

# List every submodule whose name contains "audio" (which also covers
# "embed_audio"), together with its class name and the dtype of its weight.
for name, module in model.named_modules():
    if "audio" not in name.lower():
        continue
    weight = getattr(module, "weight", None)
    dtype = getattr(weight, "dtype", None)
    print(name, type(module).__name__, dtype)

Then update llm_int8_skip_modules using the real top-level names.


Non-quantized sanity test

Before debugging quantization, it is useful to confirm that the audio pipeline works without 4-bit.

import torch
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
).eval()

Then test one audio file:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "<path-to-test.wav>"},
            {
                "type": "text",
                "text": (
                    "Transcribe the following speech segment in English into English text. "
                    "Only output the transcription, with no newlines and no extra text."
                ),
            },
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device, dtype=model.dtype)

input_len = inputs["input_ids"].shape[-1]

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,
    )

response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
print(response)

If this works, your processor, prompt, audio file handling, and model class are basically valid. If the 4-bit version then fails, the failure is specifically in the quantized path.


Revised transcription function

Your function is structurally fine, but I would change four things:

  1. Use unique temp filenames.
  2. Delete temp files in finally.
  3. Move inputs with dtype=model.dtype.
  4. Start with batch_size=1.

Revised version:

import os
import io
import uuid
import torch
import torchaudio
import soundfile as sf

def transcribe_batch(batch):
    """Transcribe every clip in ``batch["audio"]`` and return the updated batch.

    Each entry is read from its in-memory ``"bytes"``, reduced to one channel,
    resampled to 16 kHz when needed, written to a uniquely named temporary WAV
    file, and passed to the module-level ``processor`` and ``model``. The
    temporary file is removed whether or not transcription succeeds.

    Returns:
        ``batch`` with a ``"transcribed_text"`` list, one stripped string per
        input clip.
    """
    target_rate = 16000
    instruction = (
        "Transcribe the following speech segment in English into English text. "
        "Only output the transcription, with no newlines and no extra text."
    )
    results = []

    for audio_entry in batch["audio"]:
        # A random name per clip avoids collisions between clips.
        wav_path = f"/tmp/gemma4_audio_{uuid.uuid4().hex}.wav"

        try:
            samples, rate = sf.read(io.BytesIO(audio_entry["bytes"]), dtype="float32")

            # Collapse a 2-D array to 1-D by averaging over axis 1.
            if samples.ndim > 1:
                samples = samples.mean(axis=1)

            if rate != target_rate:
                resampled = torchaudio.functional.resample(
                    torch.tensor(samples, dtype=torch.float32),
                    orig_freq=rate,
                    new_freq=target_rate,
                )
                samples, rate = resampled.numpy(), target_rate

            sf.write(wav_path, samples, rate)

            chat = [
                {
                    "role": "user",
                    "content": [
                        {"type": "audio", "audio": wav_path},
                        {"type": "text", "text": instruction},
                    ],
                }
            ]

            model_inputs = processor.apply_chat_template(
                chat,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                add_generation_prompt=True,
            )
            model_inputs = model_inputs.to(model.device, dtype=model.dtype)

            # Prompt length, used to slice the prompt tokens off the output.
            prompt_len = model_inputs["input_ids"].shape[-1]

            with torch.no_grad():
                generated = model.generate(
                    **model_inputs,
                    max_new_tokens=64,
                    do_sample=False,
                )

            decoded = processor.decode(
                generated[0][prompt_len:],
                skip_special_tokens=True,
            )
            results.append(decoded.strip())

        finally:
            if os.path.exists(wav_path):
                os.remove(wav_path)

    batch["transcribed_text"] = results
    return batch

Smoke-test before full dataset mapping

Do not start with:

batch_size=32

That probably did not cause the torch.finfo() crash, but it makes debugging much harder.

Start with two samples:

processed_test = irish_male_dataset.select(range(2)).map(
    transcribe_batch,
    batched=True,
    batch_size=1,
    remove_columns=["audio"],
    load_from_cache_file=False,
    desc="Smoke test",
)

print(processed_test["transcribed_text"])

Then try a small subset:

processed_test = irish_male_dataset.select(range(16)).map(
    transcribe_batch,
    batched=True,
    batch_size=1,
    remove_columns=["audio"],
    load_from_cache_file=False,
    desc="Small test",
)

Then run the full dataset:

processed_dataset = irish_male_dataset.map(
    transcribe_batch,
    batched=True,
    batch_size=1,
    remove_columns=["audio"],
    load_from_cache_file=False,
    desc="Transcribing Audios by Irish Men",
)

df = pd.DataFrame(processed_dataset)
df.to_csv("irish_male_transcribed_gemma-4-e2b-it.csv", index=False)

After correctness is confirmed, increase gradually:

batch_size=2
batch_size=4
batch_size=8

I would not jump straight back to 32 until memory and output quality are stable.


Installation advice

I would avoid repeated downgrade/upgrade cycles in the same notebook runtime. Use a clean runtime, install, then restart the kernel.

Example:

pip uninstall -y torch torchvision torchaudio torchcodec transformers accelerate bitsandbytes datasets librosa soundfile

pip install -U torch torchvision torchaudio torchcodec \
  --index-url https://download.pytorch.org/whl/cu128

pip install -U transformers accelerate bitsandbytes datasets librosa soundfile

Then print versions:

import sys
import torch
import transformers
import accelerate
import datasets
import bitsandbytes as bnb

# Gather (label, value) pairs, then print them in a fixed order so the whole
# report can be pasted into a bug report as-is.
cuda_ok = torch.cuda.is_available()

version_report = [
    ("python:", sys.version),
    ("torch:", torch.__version__),
    ("cuda available:", cuda_ok),
    ("torch cuda:", torch.version.cuda),
    # Only query the device name when CUDA is actually usable.
    ("gpu:", torch.cuda.get_device_name(0) if cuda_ok else None),
    ("transformers:", transformers.__version__),
    ("accelerate:", accelerate.__version__),
    ("datasets:", datasets.__version__),
    # Fall back to "unknown" if the module exposes no __version__ attribute.
    ("bitsandbytes:", getattr(bnb, "__version__", "unknown")),
]

for label, value in version_report:
    print(label, value)

I would not blindly downgrade transformers. Gemma 4 support is recent, and older versions may not know the Gemma 4 processor/model mappings. That explains why downgrading can cause AutoProcessor or model-loading problems.

Relevant docs and references:


Things I would avoid

1. Do not monkey-patch torch.finfo

Avoid this except as a temporary local experiment:

# Keep a handle on the real implementation before it is replaced below.
_original_finfo = torch.finfo


def safe_finfo(dtype=None):
    """Temporary-experiment wrapper around ``torch.finfo``.

    Calls that the real ``torch.finfo`` accepts are forwarded unchanged:
    floating-point dtypes, complex dtypes, and the no-argument form
    (``dtype`` left as ``None``, meaning the current default dtype).

    Any other dtype (for example ``torch.uint8``) returns the
    ``torch.bfloat16`` limits instead of raising ``TypeError``.
    """
    if dtype is None:
        # torch.finfo() with no argument describes the default dtype.
        return _original_finfo()
    if dtype.is_floating_point or dtype.is_complex:
        return _original_finfo(dtype)
    # Non-float dtype: substitute bfloat16 limits rather than crash.
    return _original_finfo(torch.bfloat16)


# Global monkey-patch: every caller of torch.finfo in this process is affected.
torch.finfo = safe_finfo

Why this is risky:

  • it changes global PyTorch behavior;
  • it can affect unrelated packages;
  • it hides the crash without un-quantizing the audio tower;
  • it can make output quality worse or harder to diagnose;
  • the upstream fix direction is to skip audio modules, not to change PyTorch globally.

2. Do not simply replace torch.finfo with torch.iinfo

The error text mentions torch.iinfo, but in this model path the integer dtype is probably quantized storage. The model wants a floating-point clamp bound, not the integer range 0..255.

So this is not a good model-level fix:

torch.iinfo(torch.uint8).max

3. Do not assume bnb_4bit_compute_dtype=torch.bfloat16 keeps all weights bf16

It controls only the dtype used for computation. It does not guarantee that every .weight.dtype remains floating-point: bitsandbytes stores the packed 4-bit weights as torch.uint8.


Final recommendation

For your exact case, I would do this:

  1. Use a recent transformers version that supports Gemma 4.
  2. Use AutoModelForMultimodalLM, not AutoModelForCausalLM, for audio input.
  3. Quantize the main LLM body with bitsandbytes 4-bit.
  4. Skip audio_tower, embed_audio, and lm_head.
  5. Verify the skip list with model.named_modules().
  6. Test one audio file.
  7. Test irish_male_dataset.select(range(2)) with batch_size=1.
  8. Scale gradually.

The one-line diagnosis:

Your 4-bit config is probably quantizing Gemma 4’s audio tower; Gemma 4’s audio code then calls torch.finfo() on bitsandbytes’ torch.uint8 packed storage dtype, so the workaround is to keep audio modules unquantized while quantizing the language backbone.


Useful links

Thanks a lot! I can’t believe it’s finally solved!!! Hooof
It just seems that there are some minor bugs in the revised transcribe_batch, so I’m sticking with my previous function.
By the way, very nice explanation and important points.