Loading BloomForCausalLM from sharded checkpoints

I also get an error, but a slightly different one:

Traceback (most recent call last):
  File "run_inference.py", line 12, in <module>
    model = load_checkpoint_and_dispatch(
  File "/home/anaconda3/envs/research/lib/python3.8/site-packages/accelerate/big_modeling.py", line 427, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/home/anaconda3/envs/research/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 748, in load_checkpoint_in_model
    raise ValueError(f"{param_name} doesn't have any device set.")
ValueError: word_embeddings.weight doesn't have any device set.
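One way to read this error is as a name mismatch rather than a device problem. The sketch below is an assumption, not accelerate's own code. It mirrors a lookup in which each checkpoint parameter name is walked up its dotted path until some prefix appears in `device_map`, and the lookup fails when nothing matches. The hypothesis is that the device map is built from `BloomForCausalLM`, whose submodules sit under `transformer.`, while the sharded checkpoint stores bare keys such as `word_embeddings.weight`. The `device_map` contents and the helper name `find_device` are hypothetical and only for illustration.

```python
from typing import Dict, Optional, Union

Device = Union[int, str]


def find_device(param_name: str, device_map: Dict[str, Device]) -> Optional[Device]:
    """Hypothetical helper: return the device for a parameter, or None if unmapped.

    Walks up the dotted name ("a.b.weight" -> "a.b" -> "a" -> "") until a
    prefix is a key in device_map. The empty string "" acts as a catch-all
    entry only if it is present in device_map.
    """
    module_name = param_name
    while module_name and module_name not in device_map:
        # Drop the last dotted component and try the shorter prefix.
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        return None  # the situation that corresponds to "doesn't have any device set"
    return device_map[module_name]


# Hypothetical device_map for BloomForCausalLM (modules live under "transformer.").
device_map = {"transformer.word_embeddings": 0, "transformer.h.0": 0, "lm_head": 0}

# Assumed checkpoint key without the "transformer." prefix: no prefix matches.
print(find_device("word_embeddings.weight", device_map))
# The same parameter with the prefix resolves to device 0.
print(find_device("transformer.word_embeddings.weight", device_map))
```

If the first call returns `None` for your real checkpoint keys and `device_map`, the keys and the model's module names probably do not line up. Listing every key for which `find_device` returns `None` is a quick way to check whether the whole checkpoint is missing the `transformer.` prefix.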