Lazy model initialization

Hello. How can I create a model object without running the random weight initialization? For my use case that step is slow and unnecessary, because I load the weights right afterwards with load_state_dict. For instance, see the code below:

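import torch
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomBlock
config = BloomConfig.from_pretrained("bigscience/bloom")
block = BloomBlock(config)  # initializes weights randomly, which is time consuming
block.load_state_dict(torch.load("path_to_pytorch_bin"))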

same question

Use the init_empty_weights context manager from Accelerate:

from accelerate import init_empty_weights

with init_empty_weights():
    block = BloomBlock(config)

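If you load the model with from_pretrained, you can set low_cpu_mem_usage=True, which skips the random initialization (see the "Models" page of the Transformers documentation).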