How to properly UPCAST the model weights to float32?

from transformers import FlaxGemmaForCausalLM
import jax.numpy as jnp

model_name = "google/gemma-1.1-2b-it"
access_token = "hf_..."  # your Hugging Face access token (Gemma is a gated model)

# Load the model in the desired data type (bfloat16 for reduced memory usage).
# With _do_init=False, from_pretrained returns the model object and the parameter dict separately.
model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, dtype=jnp.bfloat16, token=access_token)
# model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, dtype=jnp.float32, token=access_token)
# model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, token=access_token)

WARNING:

Some of the weights of FlaxGemmaForCausalLM were initialized in bfloat16 precision from the model checkpoint at google/gemma-1.1-2b-it: [('model', 'embed_tokens', 'embedding'), ..., ('model', 'norm', 'weight')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
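
To confirm what the warning reports, the parameter dtypes can be inspected directly. A minimal sketch, assuming params is the dict returned by from_pretrained above:

import jax

# Map every parameter leaf to its dtype (prints a nested dict of dtypes).
print(jax.tree_util.tree_map(lambda p: p.dtype, params))

# Or just list the unique dtypes present in the tree.
print({p.dtype for p in jax.tree_util.tree_leaves(params)})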

I'm encountering the warning above ("Some of the weights of FlaxGemmaForCausalLM were initialized in bfloat16 precision…") every time I load the model. I have two questions:

  1. How can I suppress this warning message? Is it safe to ignore, or are there potential consequences? (See the sketch after this list for my suppression attempt.)
  2. How can I ensure all model weights are explicitly upcast to float32 for potentially higher accuracy? Is model = model.to_fp32(params) correct, and how do I use it effectively?
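
For question 1, the closest I've found is lowering the transformers log level. A minimal sketch; note this silences all warning-level messages from the library, so it only makes sense if the bfloat16 precision is intentional:

from transformers.utils import logging as hf_logging

# Hide warning-level messages from transformers (errors are still shown).
hf_logging.set_verbosity_error()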

Additionally, I'd like to know how to get a summary of the model architecture and its key parameters. This is what I tried:

print(model)
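
One partial workaround: since the parameters live in a plain dict, their shapes and total count can be dumped from the tree. A sketch, again assuming params is the dict returned by from_pretrained:

import jax

# Print the shape of every parameter leaf (a crude architecture summary).
print(jax.tree_util.tree_map(lambda p: p.shape, params))

# Total parameter count across all leaves.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Total parameters: {n_params:,}")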

Your usage looks correct; does the warning disappear when you run it?

  • After applying model = model.to_fp32(params), the warning message disappears. However, model is now a plain dictionary, and generate raises the error below:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-29-442f22ce8468> in <cell line: 1>()
----> 1 _ = p_generate(inputs, params, max_new_tokens)

    [... skipping hidden 12 frame]

<ipython-input-27-b3ddd8e3109f> in generate(inputs, params, max_new_tokens)
      1 def generate(inputs, params, max_new_tokens):
----> 2     generated_ids = model.generate(
      3         inputs["input_ids"],
      4         attention_mask=inputs["attention_mask"],
      5         params=params,

AttributeError: 'dict' object has no attribute 'generate'
  • Before applying model = model.to_fp32(params), print(model) gives: <transformers.models.gemma.modeling_flax_gemma.FlaxGemmaForCausalLM object at 0x7ef9d870a4a0>
  • After applying model = model.to_fp32(params), print(model) gives a nested dictionary of parameters. However, I'm looking for a summary of the model architecture, not the raw weights.
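
The AttributeError points at the root cause: to_fp32 returns the upcast parameter dict, not a new model, so model = model.to_fp32(params) overwrites the model object with that dict. A minimal sketch of the likely intended usage, keeping model intact (inputs and max_new_tokens are the same names as in the traceback above):

# to_fp32 returns params cast to float32; assign it back to params, not model.
params = model.to_fp32(params)

# model still has .generate; pass the upcast params in explicitly.
generated_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    params=params,
    max_new_tokens=max_new_tokens,
)

print(model)  # still the FlaxGemmaForCausalLM object, as before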