I am running inference with LLMs for text generation tasks on a TPU. I originally wrote code using the transformers pipeline with device_map set to "auto", but it never picked up the TPU device. I then passed xm.xla_device() as the device parameter, but after doing that the system crashed.
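Here is a minimal sketch of the two attempts (the model name is just a placeholder, not the actual model I used):

```python
import torch_xla.core.xla_model as xm
from transformers import pipeline

# Attempt 1: device_map="auto" never places the model on the TPU
pipe = pipeline("text-generation", model="gpt2", device_map="auto")

# Attempt 2: passing the XLA device explicitly crashes
pipe = pipeline("text-generation", model="gpt2", device=xm.xla_device())
```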
After that I wrote my own generation code. The model loaded onto the device, but generation was still happening on the CPU. This problem persists for other users too, referring to this.
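A minimal sketch of that manual generation path (again with a placeholder model): the model and inputs both go to the XLA device, yet generation still runs on the CPU.

```python
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Inputs are moved to the same XLA device as the model,
# but generate() still ends up running on the CPU
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```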
Then I tried switching from the PyTorch model to Flax. I used FlaxAutoModelForCausalLM to convert the PyTorch model to a Flax model, since a Flax version was not available. But while converting the torch model to a Flax model, I got an XLA buffer error, which stated that I had 29.7 MB of space while I needed 32.0 MB for the buffer. I tried changing the environment variable for it, but the buffer size did not change.
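The conversion step looked roughly like this (placeholder model name again); the from_pt=True load is where the XLA buffer error appears:

```python
from transformers import FlaxAutoModelForCausalLM

# No native Flax weights exist for this checkpoint, so convert from PyTorch.
# This is the step that fails with the XLA buffer error
# (29.7 MB available vs. 32.0 MB required).
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2", from_pt=True)
```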
Now I am trying the JAX models that this person re-created from PyTorch in JAX:
GitHub - ayaka14732/llama-2-jax: JAX implementation of the Llama 2 model.
Thank you.