Fine-tuning LLaMA with Accelerate: ValueError: Expected input batch_size (444) to match target batch_size (3)

Hi @sgugger and everyone. When I fine-tune LLaMA with Accelerate, the error below occurs.
I followed the data processing and Accelerate setup from the official tutorial here.
I print the shapes of the first batch and they seem right, but I don't understand why the shape changes inside the cross-entropy layer. Thanks to anyone who can help.

***** Running training *****
labels shape: torch.Size([4])
input_ids shape: torch.Size([4, 112])
attention_mask shape: torch.Size([4, 112])

(The output of the four ranks was interleaved; I deduplicated it below. All ranks raise the same error, only the input batch size differs: 380, 444, 476, 508.)

Traceback (most recent call last):
  File "/workspace/work00/sue-xie/llama/kyano_lora/src/main_paws_notwork.py", line 364, in <module>
    main()
  File "/workspace/work00/sue-xie/llama/kyano_lora/src/main_paws_notwork.py", line 316, in main
    outputs = model(**batch, use_cache=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/peft/peft_model.py", line 575, in forward
    return self.base_model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 714, in forward
    loss = loss_fct(shift_logits, shift_labels)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (444) to match target batch_size (3).

wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 82679 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 1 (pid: 82680) of binary: /opt/conda/bin/python3
Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/commands/launch.py", line 926, in launch_command
    deepspeed_launcher(args)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/commands/launch.py", line 671, in deepspeed_launcher
    distrib_run.run(args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
/workspace/work00/sue-xie/llama/kyano_lora/src/main_paws_notwork.py FAILED
============================================================
wandb: 🚀 View run silver-bird-41 at: https://wandb.ai/xie-suchun-p7/llama_lora_formal_test_1/runs/yk21vl4i
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230710_021825-yk21vl4i/logs
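If I read the loss code at modeling_llama.py line 714 correctly (paraphrased below), the numbers match a 1-D labels tensor exactly: with input_ids of shape [4, 112], flattening the shifted logits gives 4 × 111 = 444 rows, while shifting a 1-D labels of shape [4] leaves only 3 targets. (The 380, 476 and 508 on the other ranks are the same arithmetic with different padded lengths.)

# paraphrased from LlamaForCausalLM.forward in transformers
shift_logits = logits[..., :-1, :].contiguous()     # [4, 111, vocab_size]
shift_labels = labels[..., 1:].contiguous()         # labels [4] -> [3] (!)
loss = loss_fct(shift_logits.view(-1, vocab_size),  # [444, vocab_size]
                shift_labels.view(-1))              # [3] -> the ValueError above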

My code is below.
About the data processing:


def create_dataset(
    local_rank, output_path, seed, tokenizer, max_seq_len, train: bool):

    # load data (English)
    en_train = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/en_train")
    en_valid = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/en_valid")
    en_test = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/en_test")
    # Japanese (loaded but not used further below)
    ja_train = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/ja_train")
    ja_valid = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/ja_valid")
    ja_test = load_from_disk("/workspace/work00/sue-xie/llama/dataset/pawsx/ja_test")

    # combine into en_pawx & ja_pawx
    en_pawx = DatasetDict({
        'train': en_train,
        'valid': en_valid,
        'test': en_test
    })
    ja_pawx = DatasetDict({
        'train': ja_train,
        'valid': ja_valid,
        'test': ja_test
    })

    # map the "No"/"Yes" target strings to integer labels 0/1
    def replace_label(example):
        example['target'] = 0 if example['target'] == "No" else 1
        return example

    # tokenize without padding; padding is left to the collator
    def preprocess_tok(examples, max_seq_len=512):
        return tokenizer(examples["input"], max_length=max_seq_len, padding=False, truncation=True)

    en_pawx = en_pawx.map(replace_label)

    en_tokened = en_pawx.map(
        preprocess_tok,
        batched=True,
    )

    en_tokened = en_tokened.rename_column("target", "labels")
    # features: ['input', 'labels', 'input_ids', 'attention_mask']
    en_encoded = en_tokened.remove_columns(["input"])
    en_encoded.set_format("torch")

    en_encoded_train = en_encoded["train"]
    en_encoded_valid = en_encoded["valid"]
    en_train_ = en_encoded_train.shuffle(seed=seed)
    en_valid_ = en_encoded_valid.shuffle(seed=seed)

    if train:
        return en_train_
    else:
        return en_valid_
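After this processing, a single example looks roughly like the following (token id values hypothetical): labels is a scalar class id while input_ids is a variable-length sequence, which matches the labels shape: torch.Size([4]) print in the log above.

# e.g. en_train_[0], values hypothetical
# {'labels': tensor(1),
#  'input_ids': tensor([1, 450, 1023, ...]),   # length <= max_seq_len, unpadded
#  'attention_mask': tensor([1, 1, 1, ...])}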
    

About the dataloader and model:

P.S.: I cloned the code from someone else, and it works fine with the original data. The original code uses a seq2seq-style dataloader, but my data is just an "input" text plus a "label" (0 or 1), so I changed the dataset processing and the dataloader padding.
But the error doesn't go away.
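As far as I understand, DataCollatorWithPadding only pads input_ids and attention_mask and passes labels through untouched, so a collated batch should look like this (shapes taken from the print in the log above):

batch = data_collator(features)
# batch["input_ids"]      -> [4, 112]  (padded to a multiple of 8)
# batch["attention_mask"] -> [4, 112]
# batch["labels"]         -> [4]       (one class id per example, not one per token)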


import evaluate

metric = evaluate.load("accuracy")

# named evaluate_model so it does not shadow the `evaluate` library imported above
def evaluate_model(args, model, eval_dataloader, accelerator, eval_dataset):
    model.eval()
    losses = []

    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

    losses = torch.cat(losses)
    try:
        eval_loss = torch.mean(losses)
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    return perplexity, eval_loss, metric.compute()
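One thing I am not sure about: whether the predictions should also be gathered across processes before add_batch. A sketch of what I mean, using Accelerate's gather_for_metrics:

# gather predictions and references from all ranks before updating the metric
preds, refs = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(predictions=preds, references=refs)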

from transformers import AutoTokenizer, DataCollatorWithPadding

#data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, pad_to_multiple_of=8, mlm=False)

def main():
    args = parse_args()

    accelerator = Accelerator(log_with="wandb")

    hps = {"learning_rate": args.learning_rate}
    accelerator.init_trackers(args.wandb_name, config=hps)

    set_random_seed(args.seed)

    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path,
                                               fast_tokenizer=True)
    tokenizer.pad_token = tokenizer.eos_token
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer,
                                            pad_to_multiple_of=8,
                                            padding=True)
    # also tried:
    # tokenizer.pad_token_id = 0
    # tokenizer.padding_side = "left"

    model = create_hf_model(LlamaForCausalLM, args.model_name_or_path,
                            tokenizer)
    
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_dim, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)

    model = get_peft_model(model, peft_config)
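    # sanity check: print how many parameters LoRA actually makes trainable
    # (print_trainable_parameters() is a PEFT helper)
    model.print_trainable_parameters()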
    with accelerator.main_process_first():
        train_dataset = create_dataset(
            args.local_rank, # invalid
            args.data_output_path,
            args.seed,
            tokenizer,
            args.max_seq_len,
            True,
            #imitation_model=args.imitation_model # invalid
        )   
        #print(train_dataset[0])
        eval_dataset = create_dataset(
            args.local_rank,
            args.data_output_path,
            args.seed,
            tokenizer,
            args.max_seq_len,
            False,
            #imitation_model=args.imitation_model
        )

    accelerator.wait_for_everyone()

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, collate_fn=data_collator,
        batch_size=args.per_device_train_batch_size,
    )
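    # note: no shuffle=True here; create_dataset already shuffled once with a
    # fixed seed, so every epoch iterates in the same order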


    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator,
        batch_size=args.per_device_eval_batch_size,
    )


    # previous (seq2seq) version:
    # train_dataloader = DataLoader(train_dataset,
    #     collate_fn=DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8,
    #                                       return_tensors="pt", padding=True),
    #     batch_size=args.per_device_train_batch_size)
    # eval_dataloader = DataLoader(eval_dataset,
    #     collate_fn=DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8,
    #                                       return_tensors="pt", padding=True),
    #     batch_size=args.per_device_eval_batch_size)
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    # New Code #
    # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer
    optimizer_cls = (
        torch.optim.AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )

    optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_train_epochs * num_update_steps_per_epoch,
    )
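    # NOTE: the official Accelerate DeepSpeed example also guards the scheduler
    # the same way as the optimizer above: when a scheduler is defined in the
    # DeepSpeed config file, it swaps in a DummyScheduler instead, roughly:
    #
    #   from accelerate.utils import DummyScheduler
    #   lr_scheduler = DummyScheduler(
    #       optimizer,
    #       total_num_steps=args.num_train_epochs * num_update_steps_per_epoch,
    #       warmup_num_steps=args.num_warmup_steps,
    #   )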

    model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, optimizer, lr_scheduler)

   
    # Train!
    print_rank_0("***** Running training *****", accelerator.process_index)
    
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            if step == 0:
                # labels comes out 1-D ([batch]) while input_ids is [batch, seq_len];
                # LlamaForCausalLM expects labels with the same shape as input_ids
                print(f'labels shape: {batch["labels"].shape}')
                print(f'input_ids shape: {batch["input_ids"].shape}')
                print(f'attention_mask shape: {batch["attention_mask"].shape}')

            outputs = model(**batch, use_cache=False)
            train_loss = outputs.loss
            accelerator.backward(train_loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            accelerator.log({"train_loss": train_loss.item()})
            accelerator.log({"lr": lr_scheduler.get_lr()[0]})
            if step % 300 == 0:
                print_rank_0(f"Epoch is {epoch}, Step is {step}, train_loss is {train_loss.item()}", accelerator.process_index)

        ppl, eval_loss, eval_acc = evaluate_model(args, model, eval_dataloader, accelerator, eval_dataset)
        if accelerator.is_main_process:
            print_rank_0(f"eval_loss: {eval_loss}, ppl: {ppl}, acc: {eval_acc}", accelerator.process_index)
   
       
    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)

        # New Code #
        # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
        # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
        # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
        # For Zero Stages 1 and 2, models are saved as usual in the output directory.
        # The model name saved is `pytorch_model.bin`
        unwrapped_model.save_pretrained(
            args.output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
        

    accelerator.end_training()

if __name__ == "__main__":
    main()
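For reference, I launch the script roughly like this (the config file name is a placeholder; the real one enables DeepSpeed, as the traceback shows):

accelerate launch --config_file <my_ds_config.yaml> src/main_paws_notwork.py --per_device_train_batch_size 4 ...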