from transformers import FlaxGemmaForCausalLM
import jax.numpy as jnp

model_name = "google/gemma-1.1-2b-it"
access_token = "hf_..."  # placeholder: your Hugging Face token (Gemma is a gated model)

# Load the model with the desired parameter dtype (bfloat16 for reduced memory usage)
model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, dtype=jnp.bfloat16, token=access_token)

# Variants I also tried:
# model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, dtype=jnp.float32, token=access_token)
# model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, token=access_token)
WARNING:
Some of the weights of FlaxGemmaForCausalLM were initialized in bfloat16 precision from the model checkpoint at google/gemma-1.1-2b-it: [('model', 'embed_tokens', 'embedding'), ......
('model', 'norm', 'weight')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
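To double-check what precision the weights actually loaded in, I inspected the parameter leaves with plain jax.tree_util (nothing Gemma-specific), and they are indeed bfloat16:

import jax

# Collect the set of dtypes across all parameter leaves; prints {'bfloat16'} here.
print({str(leaf.dtype) for leaf in jax.tree_util.tree_leaves(params)})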
I run into this warning every time I load the model. I have two questions:
- How can I suppress the warning? Is it safe to ignore, or are there potential consequences to running with bfloat16 weights? (My attempt at silencing it is sketched right after this list.)
- How can I ensure all model weights are explicitly upcast to float32 for potentially higher accuracy?
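For the first question, the only thing I've found so far is lowering the transformers logging verbosity, which hides all warnings, not just this one, so I'm not sure it's the right tool:

from transformers import logging

# Silence transformers warnings globally (this also hides other, possibly useful, warnings).
logging.set_verbosity_error()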
For the second question, the warning points at to_fp32, and I tried:

model = model.to_fp32(params)

Is this correct, and how can I use it effectively?
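Reading the method's docstring, I suspect to_fp32 returns a new, upcast parameter tree rather than modifying the model in place, so my best guess at the intended pattern is:

# to_fp32 casts the parameter tree to float32 and returns the cast copy.
params = model.to_fp32(params)

I also see the same warning with the commented-out float32 variants above, so I assume the dtype argument only controls the computation precision, not the stored parameters?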
Additionally, how can I get a summary of the model architecture and its key parameters (layer shapes, total parameter count)?
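The best I've come up with is printing the config and walking the parameter tree myself (assuming model and params from above); is there a more standard equivalent of a model summary for Flax models?

import jax

# High-level architecture hyperparameters (number of layers, heads, hidden size, ...).
print(model.config)

# Shape and dtype of every parameter, keyed by its path in the tree.
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    print(jax.tree_util.keystr(path), leaf.shape, leaf.dtype)

# Total parameter count.
total = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"total parameters: {total:,}")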