Why does adamw_bnb_8bit skip updating embedding parameters?

I'm using Whisper for testing, but this probably applies to other models as well, since there is nothing in Whisper's source code that specifically asks for this behavior.

When using adamw_bnb_8bit as the optimizer for seq2seq tasks, I noticed that it automatically turns off the embedding parameters (at least according to the output below, which is printed right at the start of training).

Is this a quirk/property of the bitsandbytes package? I haven't read Dettmers et al.'s work, so I'm not sure whether this is expected behavior.

@sanchit-gandhi any thoughts?

(for whisper-small)

skipped Embedding(1500, 768): 1.0986328125M params
skipped Embedding(51865, 768, padding_idx=50257): 39.085693359375M params
skipped WhisperPositionalEmbedding(448, 768): 39.413818359375M params
skipped: 39.413818359375M params
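For context, the setup is nothing exotic; roughly the following is enough to trigger the message (everything except optim is a placeholder):

```python
from transformers import Seq2SeqTrainingArguments

# placeholder hyperparameters; the "skipped ..." lines appear as soon as the
# Trainer builds the optimizer, i.e. right when training starts
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-finetuned",
    optim="adamw_bnb_8bit",              # selects bitsandbytes' 8-bit AdamW
    per_device_train_batch_size=16,
    learning_rate=1e-5,
)
```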

Hey @ozanciga! Very interesting observation! My understanding is that embedding layers can become particularly unstable when downcast to lower precision (fp16/8-bit). It might be that there are exception clauses for embedding modules in the bitsandbytes package that prevent them from being downcast.

Indeed! It looks like we need a special embedding layer for 8bit embeddings: GitHub - TimDettmers/bitsandbytes: 8-bit CUDA functions for PyTorch

See step 3: Replace embedding layer if necessary: torch.nn.Embedding(..) -> bnb.nn.Embedding(..)
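In code, that step amounts to roughly this (a sketch; the dimensions are just Whisper-small's for illustration, and if I understand the library correctly bnb.nn.Embedding keeps its optimizer state in 32-bit even under the 8-bit optimizer):

```python
import bitsandbytes as bnb
import torch

# before: a standard embedding layer, optimized by a standard Adam
# emb = torch.nn.Embedding(51865, 768, padding_idx=50257)

# after: the bitsandbytes embedding plus the 8-bit Adam optimizer
emb = bnb.nn.Embedding(51865, 768, padding_idx=50257)
optimizer = bnb.optim.Adam8bit(emb.parameters(), lr=1e-5)
```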

Thank you @sanchit-gandhi, that seems about right. Maybe a question for the Hugging Face devs: should this be handled on the backend automatically?

Also, do you have any opinion on whether freezing the embeddings has any significant impact on the outcome? I'm asking specifically about Whisper, but I'm also interested in general, since using 8-bit is very desirable in most cases.

Actually, I went through the source and it turns out "skipped" means the optimizer is using 32 bits for those parameters. I monkey-patched the code to use bnb.nn.StableEmbedding instead, but I doubt it's worth it in most cases.
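For reference, the patch boils down to something like this (just a sketch, not my exact code; the attribute names assume the transformers Whisper implementation):

```python
import bitsandbytes as bnb
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# swap the decoder token embedding for bitsandbytes' StableEmbedding
# (same shape, plus a LayerNorm and 32-bit optimizer state)
old = model.model.decoder.embed_tokens
new = bnb.nn.StableEmbedding(
    old.num_embeddings, old.embedding_dim, padding_idx=old.padding_idx
)
with torch.no_grad():
    new.weight.copy_(old.weight)   # keep the pretrained weights
model.model.decoder.embed_tokens = new
```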

Hey @ozanciga! Sorry for the late reply here. Indeed, it could be worth handling this automatically in the Trainer. Feel free to open an issue on the transformers repo if you want to discuss how this might look, or open a PR directly if you already have an idea of how to fix it. I think this would make for a nice PR to the repo! Happy to help with any questions around the issue/PR!

My intuition would be that freezing the embeddings won't have a significant impact on the outcome. I believe that in the Dalle-Mini project, freezing the pre-trained embeddings actually gave superior performance vs. non-frozen. It's probably best to experiment here with, say, 1 epoch of data and see how eval WER compares with frozen vs. non-frozen embeddings.
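If it helps, freezing them for that comparison can be as simple as something like this (a sketch; the attribute names assume the transformers Whisper model):

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# disable gradients for the token and positional embeddings so the optimizer
# leaves them untouched
embedding_modules = (
    model.model.encoder.embed_positions,
    model.model.decoder.embed_tokens,
    model.model.decoder.embed_positions,
)
for module in embedding_modules:
    for param in module.parameters():
        param.requires_grad = False
```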

Now I guess this is handled automatically if the Trainer is correctly initialized on the main branch?
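From what I can tell, what it does is roughly the following (a paraphrase, not the actual source): it registers a 32-bit override for every embedding module before creating the 8-bit optimizer, which is what the "skipped" lines report.

```python
import bitsandbytes as bnb
import torch.nn as nn

def keep_embeddings_in_32bit(model: nn.Module) -> None:
    # ask bitsandbytes to keep 32-bit optimizer state for embedding weights;
    # this must run before the 8-bit optimizer is constructed
    manager = bnb.optim.GlobalOptimManager.get_instance()
    skipped = 0
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            skipped += sum(p.numel() for p in module.parameters())
            manager.register_module_override(module, "weight", {"optim_bits": 32})
            print(f"skipped {module}: {skipped / 2**20}M params")
    print(f"skipped: {skipped / 2**20}M params")
```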