Hi all, I'm trying to use Gemma with a BitsAndBytes config for quantization, then wrapping the model in a PyTorch Lightning module and running inference with the DDP strategy. I keep getting this error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
I've hit the same issue with LlamaForConditionalGeneration. For reference: if I don't pass quantization_config to .from_pretrained and leave the rest of the code exactly the same, everything works correctly.
Code for reference:
import torch
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from pytorch_lightning import Trainer

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
gemma = PaliGemmaForConditionalGeneration.from_pretrained(
    FINETUNED_MODEL_ID,
    quantization_config=bnb_config,
)
model = gem(model=gemma, processor=processor, config=config)
trainer = Trainer(
    accelerator="gpu",
    devices=7,
    strategy="ddp",
    # fast_dev_run=True,
)
predictions = trainer.predict(model, dataloaders=dm.val_dataloader())
gem is just a PyTorch Lightning wrapper, and it works fine without the BitsAndBytes config!
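One workaround I've seen suggested for bitsandbytes-quantized models under multi-GPU setups (I haven't confirmed it fixes this exact case) is to pin the whole model to each DDP rank's own GPU at load time via device_map, so each process keeps all weights on its local device instead of having them dispatched across cuda:0 and cuda:1. A minimal sketch, assuming the LOCAL_RANK environment variable that Lightning/torchrun set for each DDP process:

```python
import os

def local_rank_device_map():
    """Build a device_map that loads the entire model onto this process's GPU.

    Under DDP each process has a LOCAL_RANK env var; the {"": rank} form
    maps the model's root module (and so everything under it) to that one
    device. Defaults to device 0 when LOCAL_RANK is unset (single process).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return {"": local_rank}

# Sketch of how it would be passed at load time (names from the post above):
# gemma = PaliGemmaForConditionalGeneration.from_pretrained(
#     FINETUNED_MODEL_ID,
#     quantization_config=bnb_config,
#     device_map=local_rank_device_map(),
# )
```

With this, rank 0 loads onto cuda:0, rank 1 onto cuda:1, and so on, so no single process ends up with tensors split across two devices.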
Please let me know if anyone else has run into this as well.