load_checkpoint_and_dispatch without heavy system memory usage

I'm trying to load llama-13b for inference on a system with 24GB of VRAM and 32GB of system memory using load_checkpoint_and_dispatch. The model should fit in my combined memory, but load_checkpoint_and_dispatch appears to start by loading the whole model into system memory at full precision before moving anything to the GPU, so I run out of system memory. Is there a way around this, or is it a limitation of the current implementation? Since the model is sharded, it seems it should be possible to load the shards one at a time and move them to the GPU until it is full, and only then load the shards meant to stay in system memory.
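A quick back-of-the-envelope check of the memory numbers, as a sketch that assumes roughly 13 billion parameters (an approximation) and counts raw weights only. `from_pretrained` materializes weights in float32 unless `torch_dtype` is passed, which is 4 bytes per parameter versus 2 bytes in float16:

```python
GIB = 2**30

def weights_gib(num_params, bytes_per_param):
    """Raw weight size in GiB; ignores activations, buffers and framework overhead."""
    return num_params * bytes_per_param / GIB

num_params = 13e9  # approximate parameter count for llama-13b (assumption)
fp32 = weights_gib(num_params, 4)  # float32, i.e. "full precision"
fp16 = weights_gib(num_params, 2)  # float16

print(f"float32: {fp32:.1f} GiB, float16: {fp16:.1f} GiB")  # float32: 48.4 GiB, float16: 24.2 GiB
print("float32 copy exceeds 32 GiB of RAM:", fp32 > 32)  # True
print("float16 fits a 20 GiB GPU + 16 GiB CPU budget:", fp16 <= 20 + 16)  # True
```

So a single full-precision copy is already larger than the system RAM, while the float16 weights fit comfortably in the combined budget.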

Here's my code:

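```python
import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from huggingface_hub import hf_hub_download
from transformers import LlamaForCausalLM, LlamaTokenizer

checkpoint = "decapoda-research/llama-13b-hf"
# Only the index JSON is downloaded here, not the weight shards it points to
model_index_path = hf_hub_download(checkpoint, "pytorch_model.bin.index.json")

tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
with init_empty_weights():
    model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True).half()

# Assign each module to GPU 0 or the CPU under these memory caps
device_map = infer_auto_device_map(
    model,
    max_memory={
        0: "20GiB",
        "cpu": "16GiB"
    },
)

model = load_checkpoint_and_dispatch(
    model,
    model_index_path,
    device_map=device_map,
    no_split_module_classes=["LlamaDecoderLayer"],
    dtype=torch.float16,
)
```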
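As a mental model for what the `infer_auto_device_map` call does with `max_memory`, here is a simplified, hypothetical greedy placement: modules are visited in order and placed on the first device (in the order given in `max_memory`) that still has room. The real `infer_auto_device_map` handles more cases (nested modules, tied weights, `no_split_module_classes`), so this is only a sketch; the helper name and the layer sizes are made up for illustration.

```python
MIB = 2**20
GIB = 2**30

def greedy_device_map(module_sizes, max_memory):
    """Assign modules, in order, to the first device that still has room.

    module_sizes: list of (name, size_in_bytes) in model order.
    max_memory: dict of device -> capacity in bytes, in fill order.
    """
    remaining = dict(max_memory)
    devices = list(max_memory)
    device_map = {}
    current = 0
    for name, size in module_sizes:
        # Move on to the next device once the current one cannot hold this module.
        while current < len(devices) and remaining[devices[current]] < size:
            current += 1
        if current == len(devices):
            raise MemoryError(f"{name} ({size} bytes) does not fit on any device")
        device = devices[current]
        device_map[name] = device
        remaining[device] -= size
    return device_map

# Illustrative numbers: 40 decoder layers of 600 MiB each.
layers = [(f"model.layers.{i}", 600 * MIB) for i in range(40)]
device_map = greedy_device_map(layers, {0: 20 * GIB, "cpu": 16 * GIB})
on_gpu = sum(1 for d in device_map.values() if d == 0)
print(on_gpu, len(device_map) - on_gpu)  # 34 6
```

Filling the GPU first and spilling the rest to the CPU is the same order as the shard-by-shard loading described in the question.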

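You should just use

```python
# device_map="auto" makes from_pretrained dispatch the weights across GPU and CPU
model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
```

which will do the same thing.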

load_checkpoint_and_dispatch will load the checkpoint shard by shard, but it can't work with the code sample you provided: it requires the shards to live in the same folder as the model index, and your code only downloads the index file with hf_hub_download, not the shards it points to.
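To check that requirement before calling `load_checkpoint_and_dispatch`, here is a minimal sketch with a hypothetical helper `missing_shards` (not part of accelerate or transformers). It reads the `weight_map` of a sharded index file, which maps each parameter name to a shard file name, and lists the shard files that are not present next to the index. The file names in the example are made up.

```python
import json
import tempfile
from pathlib import Path

def missing_shards(index_path):
    """Return shard files listed in a sharded-checkpoint index that are absent from its folder."""
    index_path = Path(index_path)
    with index_path.open() as f:
        weight_map = json.load(f)["weight_map"]  # parameter name -> shard file name
    folder = index_path.parent
    # Many parameters share one shard, so deduplicate before checking the disk.
    shard_files = sorted(set(weight_map.values()))
    return [name for name in shard_files if not (folder / name).is_file()]

# Example: an index that references two shards, only one of which is on disk.
with tempfile.TemporaryDirectory() as tmp:
    index = Path(tmp) / "pytorch_model.bin.index.json"
    index.write_text(json.dumps({
        "metadata": {"total_size": 0},
        "weight_map": {
            "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
            "lm_head.weight": "pytorch_model-00002-of-00002.bin",
        },
    }))
    (Path(tmp) / "pytorch_model-00001-of-00002.bin").touch()
    print(missing_shards(index))  # ['pytorch_model-00002-of-00002.bin']
```

One way to get the shards next to the index is `huggingface_hub.snapshot_download(checkpoint)`, which downloads the repository files and returns the local folder path; that folder (or the index inside it) can then be passed to `load_checkpoint_and_dispatch`.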

Also, that checkpoint does not work at all (see the ~70 open PRs that are being ignored); you should really use the conversion script on the official weights or use another checkpoint.