Pipeline device issue, torch_xla generate() bug, Flax model malloc errors

I am running inference with LLMs for text generation tasks on a TPU. I originally wrote code using the transformers pipeline with device_map set to "auto", but it never picked up the TPU device. I then passed xm.xla_device() as the device parameter, but after doing that the system crashes.
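Roughly what I tried looks like this (the model name here is just a placeholder, not the model I'm actually using):

```python
import torch_xla.core.xla_model as xm
from transformers import pipeline

# Attempt 1: device_map="auto" never selects the TPU.
pipe = pipeline("text-generation", model="gpt2", device_map="auto")

# Attempt 2: passing the XLA device explicitly crashes the process.
pipe = pipeline("text-generation", model="gpt2", device=xm.xla_device())
```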

After that I wrote my own generation code with model.generate(). The model loaded on the TPU device, but generation was still happening on the CPU. This problem affects other users too; see this issue:

Generate text with `model.generate` on TPU does not work · Issue #12322 · huggingface/transformers · GitHub .
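For reference, my generation code was roughly like this (the checkpoint name is a placeholder):

```python
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()
tokenizer = AutoTokenizer.from_pretrained("gpt2")                 # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)   # loads onto the TPU

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
# generate() runs, but the computation appears to fall back to the CPU,
# matching the behaviour described in the linked issue.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```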

Then I tried to convert the PyTorch model to Flax, since a Flax version was not available. I used FlaxAutoModelForCausalLM to convert the PyTorch model to a Flax model, but during the conversion I got an XLA buffer error stating that I had 29.7 MB of space while 32.0 MB was needed for the buffer. I tried changing the environment variable for it, but the buffer size did not change.
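The conversion attempt was essentially this (checkpoint name is a placeholder; from_pt=True converts the PyTorch weights on the fly):

```python
from transformers import FlaxAutoModelForCausalLM

# Placeholder checkpoint; the real model has no published Flax weights.
flax_model = FlaxAutoModelForCausalLM.from_pretrained(
    "gpt2",
    from_pt=True,  # convert the PyTorch checkpoint to Flax
)
# This step fails with the XLA buffer error (29.7 MB available vs 32.0 MB required).
```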

Now I am trying to use the JAX models that this person re-created from the PyTorch weights:

GitHub - ayaka14732/llama-2-jax: JAX implementation of the Llama 2 model .

Thank you.