Hello.
I’m trying to inference Gemma any-2-any models and quantize them. I faced with lot’s of different errors due to versioning and installing packages. I tried many Github solutions, but if I downgrade transformers then the AutoProcessor could not be found.
!pip uninstall -y torch torchvision torchaudio torchcodec -q
!pip install -U \
"torch==2.11.0" \
"torchvision==0.26.0" \
"torchaudio==2.11.0" \
torchcodec \
--index-url https://download.pytorch.org/whl/cu128 -q
!pip install -U transformers librosa accelerate bitsandbytes -q
!pip install soundfile -q
!pip install "datasets==2.21.0" -q
import os
import io
import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
import datasets
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import logging
logging.set_verbosity(logging.CRITICAL)
import logging
logging.getLogger().setLevel(logging.ERROR)
import warnings
warnings.filterwarnings("ignore")
irish_male_dataset = load_dataset("ylacombe/english_dialects", "irish_male", split="train")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
MODEL_ID = "google/gemma-4-E2B-it"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
quantization_config=bnb_config,
).eval()
def transcribe_batch(batch):
transcriptions = []
for i, f in enumerate(batch["audio"]):
audio_path = f["path"]
if audio_path and os.path.exists(audio_path):
wav, sr = sf.read(io.BytesIO(f["bytes"]), dtype="float32")
else:
wav, sr = sf.read(io.BytesIO(f["bytes"]), dtype="float32")
if wav.ndim > 1:
wav = wav.mean(axis=1)
if sr != 16000:
wav = torchaudio.functional.resample(
torch.tensor(wav, dtype=torch.float32),
orig_freq=sr,
new_freq=16000
).numpy()
tmp_path = f"/tmp/tmp_audio_{i}.wav"
sf.write(tmp_path, wav, 16000)
messages = [
{
"role": "user",
"content": [
{"type": "audio", "audio": tmp_path},
{"type": "text", "text": (
"Transcribe the following speech segment in English into English text. Don't generate additonal text."
)},
]
}
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=64)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
transcriptions.append(response)
torch.cuda.empty_cache()
batch["transcribed_text"] = transcriptions
print(transcriptions)
return batch
irish_male_dataset = irish_male_dataset.cast_column("audio", datasets.features.Audio(decode=False))
processed_dataset = irish_male_dataset.map(
transcribe_batch,
batched=True,
batch_size=32,
remove_columns=["audio"],
load_from_cache_file=False,
desc="Transcribing Audios by Irish Men"
)
df = pd.DataFrame(processed_dataset)
df.to_csv("irish_male_transcribed_gemma-4-e2b-it.csv", index=False)
Here is the full error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_57/390005493.py in <cell line: 0>()
----> 1 processed_dataset = irish_male_dataset.map(
2 transcribe_batch,
3 batched=True,
4 batch_size=32,
5 remove_columns=["audio"],
/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
600 self: "Dataset" = kwargs.pop("self")
601 # apply actual function
--> 602 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
603 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
604 for dataset in datasets:
/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
565 }
566 # apply actual function
--> 567 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
568 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
569 # re-apply format to the output
/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3165 desc=desc or "Map",
3166 ) as pbar:
-> 3167 for rank, done, content in Dataset._map_single(**dataset_kwargs):
3168 if done:
3169 shards_done += 1
/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in _map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
3556 ) # Something simpler?
3557 try:
-> 3558 batch = apply_function_on_filtered_inputs(
3559 batch,
3560 indices,
/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py in apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
3425 if with_rank:
3426 additional_args += (rank,)
-> 3427 processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
3428 if isinstance(processed_inputs, LazyDict):
3429 processed_inputs = {
/tmp/ipykernel_57/2288985515.py in transcribe_batch(batch)
44
45 with torch.no_grad():
---> 46 outputs = model.generate(**inputs, max_new_tokens=64)
47
48 response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
122 # pyrefly: ignore [bad-context-manager]
123 with ctx_factory():
--> 124 return func(*args, **kwargs)
125
126 return decorate_context
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, custom_generate, **kwargs)
2558
2559 # 9. Call generation mode
-> 2560 result = decoding_method(
2561 self,
2562 input_ids,
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
2751
2752 prefill_consumed = False
-> 2753 outputs = self._prefill(
2754 input_ids,
2755 generation_config,
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _prefill(self, input_ids, generation_config, model_kwargs, is_first_iteration)
3797 **model_kwargs,
3798 )
-> 3799 return self(**model_inputs, return_dict=True)
3800
3801 # Chunked prefill (for very large contexts)
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1777 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1778 else:
-> 1779 return self._call_impl(*args, **kwargs)
1780
1781 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1788 or _global_backward_pre_hooks or _global_backward_hooks
1789 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790 return forward_call(*args, **kwargs)
1791
1792 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
898 if return_dict_passed is not None:
899 return_dict = return_dict_passed
--> 900 output = func(self, *args, **kwargs)
901 if not return_dict and not isinstance(output, tuple):
902 output = output.to_tuple()
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, image_position_ids, video_position_ids, past_key_values, mm_token_type_ids, inputs_embeds, labels, use_cache, logits_to_keep, **kwargs)
2534 Passed through to the vision encoder for positional embedding computation.
2535 """
-> 2536 outputs = self.model(
2537 input_ids=input_ids,
2538 pixel_values=pixel_values,
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1777 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1778 else:
-> 1779 return self._call_impl(*args, **kwargs)
1780
1781 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1788 or _global_backward_pre_hooks or _global_backward_hooks
1789 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790 return forward_call(*args, **kwargs)
1791
1792 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
974 output = func(self, *args, **kwargs)
975 else:
--> 976 output = func(self, *args, **kwargs)
977 # Restore original config value
978 finally:
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
898 if return_dict_passed is not None:
899 return_dict = return_dict_passed
--> 900 output = func(self, *args, **kwargs)
901 if not return_dict and not isinstance(output, tuple):
902 output = output.to_tuple()
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, past_key_values, mm_token_type_ids, inputs_embeds, use_cache, image_position_ids, video_position_ids, **kwargs)
2343 # Merge text and audio
2344 if input_features is not None and input_features_mask is not None:
-> 2345 audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
2346 audio_features = audio_output.pooler_output
2347 audio_mask_from_encoder = audio_output.attention_mask # True = valid
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
898 if return_dict_passed is not None:
899 return_dict = return_dict_passed
--> 900 output = func(self, *args, **kwargs)
901 if not return_dict and not isinstance(output, tuple):
902 output = output.to_tuple()
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in get_audio_features(self, input_features, input_features_mask, **kwargs)
2438 )
2439
-> 2440 audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
2441 audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
2442
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1777 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1778 else:
-> 1779 return self._call_impl(*args, **kwargs)
1780
1781 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1788 or _global_backward_pre_hooks or _global_backward_hooks
1789 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790 return forward_call(*args, **kwargs)
1791
1792 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
974 output = func(self, *args, **kwargs)
975 else:
--> 976 output = func(self, *args, **kwargs)
977 # Restore original config value
978 finally:
/usr/local/lib/python3.12/dist-packages/transformers/utils/output_capturing.py in wrapper(self, *args, **kwargs)
246 # Run the forward
247 try:
--> 248 outputs = func(self, *args, **kwargs)
249 # Reset the states
250 finally:
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, input_features, attention_mask, **kwargs)
1965
1966 for encoder_layer in self.layers[: self.config.num_hidden_layers]:
-> 1967 hidden_states = encoder_layer(
1968 hidden_states,
1969 attention_mask=attention_mask,
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1777 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1778 else:
-> 1779 return self._call_impl(*args, **kwargs)
1780
1781 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1788 or _global_backward_pre_hooks or _global_backward_hooks
1789 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790 return forward_call(*args, **kwargs)
1791
1792 result = None
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, hidden_states, attention_mask, position_embeddings, **kwargs)
544 gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
545
--> 546 hidden_states = self.feed_forward1(hidden_states)
547 residual = hidden_states
548
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1777 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1778 else:
-> 1779 return self._call_impl(*args, **kwargs)
1780
1781 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1788 or _global_backward_pre_hooks or _global_backward_hooks
1789 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790 return forward_call(*args, **kwargs)
1791
1792 result = None
/usr/local/lib/python3.12/dist-packages/transformers/models/gemma4/modeling_gemma4.py in forward(self, hidden_states)
425 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
426 # This is needed to avoid any underflow/overflow issues when clipping
--> 427 gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
428
429 residual = hidden_states
TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'