I'm trying to use the PaliGemma model, following the example from its model card:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
# Caption a sample image in Spanish with PaliGemma.
#
# Steps: download a test image, load the bfloat16 checkpoint onto the GPU,
# build the model inputs from a task prompt plus the image, generate greedily,
# and print only the newly generated text.
#
# NOTE(review): the "Could not infer dtype of numpy.float32" error reported
# with this script is presumably an environment problem rather than a problem
# in the code: it is commonly seen when the installed torch build is not
# compatible with NumPy 2.x. Pinning `numpy<2` (or upgrading torch) is the
# usual workaround -- confirm against the installed versions.
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"

# Fail fast on a stalled or failed download instead of handing an HTML error
# page to PIL, and release the connection once the image has been decoded.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces the full decode while the stream is still open.
    image = Image.open(response.raw).convert("RGB")

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Instruct the model to create a caption in Spanish
prompt = "caption es"
model_inputs = processor(
    text=prompt,
    images=image,
    return_tensors="pt",
    padding=True,
).to(model.device)

# Remember the prompt length so the echoed input tokens can be stripped from
# the generated sequence below.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)
I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\feature_extraction_utils.py:182, in BatchFeature.convert_to_tensors(self, tensor_type)
181 if not is_tensor(value):
--> 182 tensor = as_tensor(value)
184 self[key] = tensor
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\feature_extraction_utils.py:141, in BatchFeature._get_is_as_tensor_fns.<locals>.as_tensor(value)
140 value = np.array(value)
--> 141 return torch.tensor(value)
RuntimeError: Could not infer dtype of numpy.float32
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[4], line 23
21 # Instruct the model to create a caption in Spanish
22 prompt = "caption es"
---> 23 model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
24 input_len = model_inputs["input_ids"].shape[-1]
26 with torch.inference_mode():
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\models\paligemma\processing_paligemma.py:250, in PaliGemmaProcessor.__call__(self, text, images, tokenize_newline_separately, padding, truncation, max_length, return_tensors, do_resize, do_normalize, image_mean, image_std, data_format, input_data_format, resample, do_convert_rgb, do_thumbnail, do_align_long_axis, do_rescale, suffix)
238 suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
240 input_strings = [
241 build_string_from_input(
242 prompt=prompt,
(...)
247 for prompt in text
248 ]
--> 250 pixel_values = self.image_processor(
251 images,
252 do_resize=do_resize,
253 do_normalize=do_normalize,
254 return_tensors=return_tensors,
255 image_mean=image_mean,
256 image_std=image_std,
257 input_data_format=input_data_format,
258 data_format=data_format,
259 resample=resample,
260 do_convert_rgb=do_convert_rgb,
261 )["pixel_values"]
263 if max_length is not None:
264 max_length += self.image_seq_length # max_length has to account for the image tokens
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\image_processing_utils.py:551, in BaseImageProcessor.__call__(self, images, **kwargs)
549 def __call__(self, images, **kwargs) -> BatchFeature:
550 """Preprocess an image or a batch of images."""
--> 551 return self.preprocess(images, **kwargs)
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\models\siglip\image_processing_siglip.py:259, in SiglipImageProcessor.preprocess(self, images, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, return_tensors, data_format, input_data_format, do_convert_rgb, **kwargs)
254 images = [
255 to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
256 ]
258 data = {"pixel_values": images}
--> 259 return BatchFeature(data=data, tensor_type=return_tensors)
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\feature_extraction_utils.py:78, in BatchFeature.__init__(self, data, tensor_type)
76 def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
77 super().__init__(data)
---> 78 self.convert_to_tensors(tensor_type=tensor_type)
File c:\Users\voelkev\GitHub\PaliGemma\.venv\Lib\site-packages\transformers\feature_extraction_utils.py:188, in BatchFeature.convert_to_tensors(self, tensor_type)
186 if key == "overflowing_values":
187 raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
--> 188 raise ValueError(
189 "Unable to create tensor, you should probably activate padding "
190 "with 'padding=True' to have batched tensors with the same length."
191 )
193 return self
ValueError: Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.
Package Versions:
Package Version
------------------ ------------
accelerate 0.31.0
asttokens 2.4.1
bitsandbytes 0.43.1
certifi 2024.6.2
charset-normalizer 2.0.12
colorama 0.4.6
comm 0.2.2
debugpy 1.8.2
decorator 5.1.1
executing 2.0.1
filelock 3.15.4
fsspec 2024.6.0
huggingface-hub 0.23.4
idna 3.7
inquirerpy 0.3.4
intel-openmp 2021.4.0
ipykernel 6.29.4
ipython 8.25.0
ipywidgets 8.1.3
jedi 0.19.1
Jinja2 3.1.4
jupyter_client 8.6.2
jupyter_core 5.7.2
jupyterlab_widgets 3.0.11
MarkupSafe 2.1.5
matplotlib-inline 0.1.7
mkl 2021.4.0
mpmath 1.3.0
nest-asyncio 1.6.0
networkx 3.3
numpy 2.0.0
packaging 24.1
parso 0.8.4
pfzy 0.3.4
pillow 10.3.0
pip 24.1
platformdirs 4.2.2
prompt_toolkit 3.0.47
psutil 6.0.0
pure-eval 0.2.2
Pygments 2.18.0
python-dateutil 2.9.0.post0
pywin32 306
PyYAML 6.0.1
pyzmq 26.0.3
regex 2024.5.15
requests 2.27.1
safetensors 0.4.3
setuptools 65.5.0
six 1.16.0
stack-data 0.6.3
sympy 1.12.1
tbb 2021.13.0
tokenizers 0.19.1
torch 2.3.1+cu118
torchaudio 2.3.1+cu118
torchvision 0.18.1+cu118
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
transformers 4.41.2
typing_extensions 4.12.2
urllib3 1.26.19
wcwidth 0.2.13
widgetsnbextension 4.0.11
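Padding seems to be enabled in my call (`padding=True` is passed to the processor), so the hint in the error message does not seem to apply.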