Sharded checkpoints

This link shows how you can set memory limits using device_map.
But before you can do that, you already need a sharded checkpoint for the function below. How do you get sharded checkpoints if the model can’t fit on your GPUs to begin with?
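For context, a sharded checkpoint is a model's state dict split across several files, plus an index JSON that maps each weight name to the file that holds it. transformers' save_pretrained writes one when the weights exceed its max_shard_size argument, and load_checkpoint_and_dispatch accepts a folder laid out that way. The sketch below imitates only the greedy packing step, using a hypothetical shard_weights helper and made-up sizes; the file names merely mimic the usual pytorch_model-0000X-of-0000Y.bin pattern, and this is not the transformers implementation.

```python
import json
from typing import Dict, List


def shard_weights(sizes: Dict[str, int], max_shard_size: int) -> dict:
    """Greedily pack weights (name -> size in bytes) into shards.

    Hypothetical helper for illustration. Returns a dict shaped like a
    sharded-checkpoint index: {"metadata": {...}, "weight_map": {name: file}}.
    """
    shards: List[List[str]] = [[]]
    current = 0
    for name, size in sizes.items():
        # Start a new shard if this weight would overflow a non-empty shard.
        # A single weight larger than the limit still gets a shard of its own.
        if shards[-1] and current + size > max_shard_size:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += size

    total = len(shards)
    weight_map = {}
    for k, names in enumerate(shards, start=1):
        filename = f"pytorch_model-{k:05d}-of-{total:05d}.bin"
        for name in names:
            weight_map[name] = filename
    return {"metadata": {"total_size": sum(sizes.values())}, "weight_map": weight_map}


if __name__ == "__main__":
    GB = 10**9
    # Made-up sizes, only to show how the weights get grouped.
    sizes = {
        "shared.weight": 1 * GB,
        "encoder.block.0": 4 * GB,
        "encoder.block.1": 4 * GB,
        "decoder.block.0": 4 * GB,
    }
    print(json.dumps(shard_weights(sizes, max_shard_size=10 * GB), indent=2))
```

Because each shard file can be read and placed on its own, the whole model never has to be materialized on a single GPU, which is what lets a device map spread it out.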

The whole reason I’m doing this is that when I use the shard option, I get CUDA out-of-memory errors.
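A back-of-the-envelope estimate helps explain out-of-memory errors like these: the weights alone need roughly the parameter count times the bytes per parameter, and training roughly doubles that once gradients are kept. The sketch uses a hypothetical estimate_gib helper and an assumed round figure of 6 billion parameters for a "6B" model; it ignores activations, optimizer state, and CUDA overhead, so real usage is higher.

```python
def estimate_gib(num_params: int, bytes_per_param: int, with_grads: bool = False) -> float:
    """Rough memory estimate in GiB for weights, optionally plus gradients.

    Hypothetical helper: ignores activations, optimizer state and framework overhead.
    """
    total_bytes = num_params * bytes_per_param
    if with_grads:
        # Gradients have the same shape and dtype as the weights.
        total_bytes *= 2
    return total_bytes / 2**30


if __name__ == "__main__":
    n = 6_000_000_000  # assumed round parameter count for a "6B" model
    for name, nbytes in [("fp32", 4), ("fp16", 2)]:
        print(
            f"{name}: weights ~{estimate_gib(n, nbytes):.1f} GiB, "
            f"weights+grads ~{estimate_gib(n, nbytes, with_grads=True):.1f} GiB"
        )
```

With these numbers, the fp32 weights of a 6B-parameter model are already well above a 10 GiB per-GPU budget before any activations are counted.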

  from accelerate import load_checkpoint_and_dispatch
  model = load_checkpoint_and_dispatch(
      model, "sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"]
  )
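Roughly speaking, a device map built under memory limits fills the first GPU with whole blocks until its budget runs out, then moves on to the next GPU and finally the CPU, while no_split_module_classes (GPTJBlock above) marks blocks that must stay on a single device. The sketch below is a heavily simplified illustration of that idea, using a hypothetical greedy_device_map function and made-up block sizes; Accelerate's real infer_auto_device_map handles more cases than this, so treat it only as intuition.

```python
from typing import Dict, List, Tuple, Union

Device = Union[int, str]


def greedy_device_map(
    blocks: List[Tuple[str, int]],
    max_memory: Dict[Device, int],
) -> Dict[str, Device]:
    """Place each (name, size_in_bytes) block on the first device with room left.

    Simplified illustration: every block is treated as unsplittable (like a class
    listed in no_split_module_classes), and the walk never returns to an earlier
    device, so consecutive blocks keep their order across devices.
    """
    devices = list(max_memory)  # insertion order, e.g. [0, 1, "cpu"]
    remaining = dict(max_memory)
    device_map: Dict[str, Device] = {}
    idx = 0
    for name, size in blocks:
        # Advance to the next device until this whole block fits.
        while idx < len(devices) and remaining[devices[idx]] < size:
            idx += 1
        if idx == len(devices):
            raise MemoryError(f"no device has room for {name!r} ({size} bytes)")
        device = devices[idx]
        device_map[name] = device
        remaining[device] -= size
    return device_map


if __name__ == "__main__":
    GiB = 2**30
    # Made-up sizes: eight blocks of 3 GiB each.
    blocks = [(f"h.{k}", 3 * GiB) for k in range(8)]
    print(greedy_device_map(blocks, {0: 10 * GiB, 1: 10 * GiB, "cpu": 120 * GiB}))
```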

I figured it out, but the max memory mapping didn’t work anyway.

I just edited the config for a T5 model, instantiated a model from it, and saved it. This gave me the sharded checkpoints.

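  from transformers import T5Config, T5ForConditionalGeneration

  # Instantiate a randomly initialized model from the edited config and save it;
  # save_pretrained splits the weights into shard files plus an index JSON
  # once they exceed max_shard_size.
  # T5ForConditionalGeneration (unlike T5Model) has the LM head needed for labels.
  my_config = T5Config.from_pretrained("/media/New Volume/models/need to make/xl-nl32/")
  model = T5ForConditionalGeneration(my_config)
  model.save_pretrained("/media/New Volume/models/need to make/xl-nl32made")

Then, with Accelerate:

  import torch
  from accelerate import Accelerator, init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
  from transformers import AutoConfig, T5ForConditionalGeneration
  from transformers.optimization import Adafactor

  checkpoint = "/media/New Volume/models/need to make/xl-nl32made"
  config = AutoConfig.from_pretrained(checkpoint)
  with init_empty_weights():
      # Create the model skeleton on the meta device (no real weight memory yet)
      model = T5ForConditionalGeneration(config)

  # Cap memory per device and keep each T5Block whole on a single device
  my_device_map = infer_auto_device_map(
      model,
      max_memory={0: "10GiB", 1: "10GiB", "cpu": "120GiB"},
      no_split_module_classes=["T5Block"],
  )
  # Load the sharded checkpoint and place each module according to the map
  model = load_checkpoint_and_dispatch(model, checkpoint, device_map=my_device_map)

  accelerator = Accelerator()
  inputs = torch.load('')
  label = torch.load('')
  count = 0
  steps = 10000
  printSteps = 50
  totalLoss = 0

  # Depending on the Accelerate version, prepare() may move the whole model onto
  # accelerator.device, which would override the device map built above.
  inputs, label, model = accelerator.prepare(inputs, label, model)
  optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=0.00001)
  optimizer = accelerator.prepare(optimizer)

  model.train()
  for i, j in zip(inputs, label):
      # accumulate() is meant to wrap each training step, not the whole loop
      with accelerator.accumulate(model):
          # Each entry is a NumPy array of token ids; turn it into a (1, seq_len) LongTensor
          i = torch.from_numpy(i).long().reshape(1, -1)
          j = torch.from_numpy(j).long().reshape(1, -1)
          loss = model(input_ids=i, labels=j, use_cache=False).loss
          # Backpropagate and update the weights
          accelerator.backward(loss)
          optimizer.step()
          optimizer.zero_grad()
      count += 1
      totalLoss += loss.item()

      if count % printSteps == 0:
          avgLoss = totalLoss / printSteps
          print("step:", count, "loss:", avgLoss)
          if avgLoss < 1.0:
              break
          totalLoss = 0

      if count == steps:
          break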
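One detail worth double-checking in the max_memory dict is the unit: Accelerate's size strings distinguish binary units such as "GiB" (2**30 bytes) from decimal units such as "GB" (10**9 bytes), so "10GiB" is about 7% more than "10GB". The hypothetical parse_size below mimics that convention for illustration only; it is not Accelerate's own parser and accepts just a few unit spellings.

```python
import re

# Bytes per unit: decimal units (KB, MB, GB) versus binary units (KiB, MiB, GiB).
_UNITS = {
    "KB": 10**3, "MB": 10**6, "GB": 10**9,
    "KiB": 2**10, "MiB": 2**20, "GiB": 2**30,
}


def parse_size(size: str) -> int:
    """Convert a size string such as "10GiB" or "120GB" into bytes (illustrative)."""
    match = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*(KB|MB|GB|KiB|MiB|GiB)\s*", size)
    if match is None:
        raise ValueError(f"unrecognized size: {size!r}")
    value, unit = match.groups()
    return int(float(value) * _UNITS[unit])


if __name__ == "__main__":
    max_memory = {0: "10GiB", 1: "10GiB", "cpu": "120GiB"}
    for device, size in max_memory.items():
        print(device, parse_size(size))
```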

But this memory mapping didn’t work. I tried it both by launching through Accelerate with different configs and by running the file on its own.
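One way to check whether the mapping did what you expect is to summarize the dict returned by infer_auto_device_map (my_device_map above) per device, and compare that with torch.cuda.memory_allocated(i) for each GPU after loading. The sketch below uses a hypothetical summarize_device_map helper and a made-up map; it only counts how many entries landed on each device, so a map where everything sits on one device stands out immediately.

```python
from collections import Counter
from typing import Dict, Union

Device = Union[int, str]


def summarize_device_map(device_map: Dict[str, Device]) -> Dict[Device, int]:
    """Count how many modules were assigned to each device (hypothetical helper)."""
    return dict(Counter(device_map.values()))


if __name__ == "__main__":
    # Made-up example shaped like the output of infer_auto_device_map.
    example = {
        "shared": 0,
        "encoder.block.0": 0,
        "encoder.block.1": 1,
        "decoder.block.0": 1,
        "decoder.final_layer_norm": "cpu",
    }
    for device, count in summarize_device_map(example).items():
        print(f"{device}: {count} module(s)")
```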

Can you provide more details on what you did and how you tried to set your max memory mapping?

I updated my response.