Low bf16 performance on TPU, int4 vs int8 quantization

Hi, I am trying to finetune Llama 3 with LoRA using the most recent versions of peft, accelerate, torch and bitsandbytes, and I am struggling with the following:

  1. In the most recent version, bitsandbytes is missing

bnb_8bit_compute_dtype=torch.bfloat16

in the BitsAndBytesConfig for 8-bit; the compute dtype setting is present only for 4-bit, and I could not find out whether the 4-bit and 8-bit options can be used interchangeably.
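
For context, here is roughly the kind of config I mean (just a sketch; the model id is only an example, and the missing 8-bit compute dtype kwarg is exactly what I cannot find in the current API):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit: the compute dtype is configurable here
bnb_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# 8-bit: no bnb_8bit_compute_dtype equivalent that I can find
bnb_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # example model id
    quantization_config=bnb_4bit,
    device_map="auto",
)
```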

  2. Is there any comparison of finetuning performance between loading the model in 4-bit and in 8-bit?

  3. For some reason, setting the compute dtype to bfloat16 in the config makes everything incredibly slow on TPU with the most recent version of the libraries (on Kaggle). It is 10 times slower than doing exactly the same thing with float16!

bnb_4bit_compute_dtype=torch.bfloat16
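
To show the kind of gap I mean, here is a rough microbenchmark sketch (sizes and iteration counts are arbitrary, and it assumes a GPU; on the Kaggle TPU the device would come from torch_xla's xla_device() instead of "cuda"):

```python
import time
import torch

def avg_matmul_time(dtype, device, n=4096, iters=50):
    """Rough average wall-clock time of an n x n matmul in the given dtype."""
    a = torch.randn(n, n, device=device, dtype=dtype)
    b = torch.randn(n, n, device=device, dtype=dtype)
    for _ in range(5):  # warmup
        _ = a @ b
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        _ = a @ b
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.time() - start) / iters

device = "cuda"  # assumed GPU; on TPU you would use an XLA device and force execution between timings
print("bf16:", avg_matmul_time(torch.bfloat16, device))
print("fp16:", avg_matmul_time(torch.float16, device))
```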

  4. If I finetune a model with LoRA on GPU using compute dtype bfloat16, and then for inference load it in float16 on TPU (because, again, bfloat16 is slow as hell), the quality will degrade, right? Since LoRA also adapts to the rounding errors introduced during quantization?
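
Concretely, this is roughly how I would load it for inference (model id and adapter path are placeholders; whether merging the adapter like this after bfloat16 training loses quality in float16 is exactly what I am asking about):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_id = "meta-llama/Meta-Llama-3-8B"   # example base model
adapter_dir = "./llama3-lora-adapter"    # placeholder path to the trained LoRA adapter

# load the base model in float16 for inference, even though training used bfloat16 compute dtype
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_dir)
model = model.merge_and_unload()  # fold the LoRA weights into the base weights
model.eval()
```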

Thanks in advance!