Accelerator() causes an error when using multiple GPUs

I am trying to use the accelerate library with transformers.Trainer. However, I always get an error when training on multiple GPUs. Here is a minimal version of my train.py file:

import os
import hydra
from omegaconf import DictConfig
from peft import LoraConfig, get_peft_model, TaskType

from transformers import DataCollatorForSeq2Seq, AutoTokenizer, AutoConfig, Trainer, TrainingArguments, AutoModelForCausalLM

from accelerate import Accelerator

from utils.data import Dataset
import utils


@hydra.main(config_path="config/train", config_name="numbersorting", version_base="1.3")
def main(cfg: DictConfig):
    # set seed
    utils.seed_everything(cfg.run.seed)

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer.name, token=HF_TOKEN, cache_dir=model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    # data. Dataset.__getitem__() returns a dictionary with 'input_ids', 'labels' and 'attention_mask' for each example.
    # All returned values are torch tensors on CPU.
    train_dataset = Dataset(...)
    test_dataset = Dataset(...)

    # get model
    config = AutoConfig.from_pretrained(cfg.model.architecture, load_in_8_bit=cfg.model.load_weights_in_8_bit)
    model = AutoModelForCausalLM.from_pretrained(cfg.model.architecture, config=config, cache_dir=model_dir)

    # data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

    # LoRA
    if cfg.training.lora.enabled:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=cfg.training.lora.rank,
            lora_alpha=cfg.training.lora.alpha,
            lora_dropout=cfg.training.lora.dropout,
        )
        model = get_peft_model(model, peft_config)
        print("Trainable Parameters:")
        model.print_trainable_parameters()

    # Initialize the Accelerator
    accelerator = Accelerator(mixed_precision=cfg.training.precision)

    # Configure model, tokenizer, optimizers, and data collators to work with Accelerate
    model, tokenizer, data_collator = accelerator.prepare(model, tokenizer, data_collator)

    # Trainer
    args = TrainingArguments(
        output_dir=model_dir,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        evaluation_strategy="steps",
        eval_steps=100,
        logging_steps=100,
        save_steps=100,
        push_to_hub=False,
        do_eval=True,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )

    # training
    trainer.train()
    trainer.save_model(model_dir)


if __name__ == "__main__":
    HF_TOKEN = os.environ["HUGGINGFACE_TOKEN"]
    model_dir = os.path.join("path/to/artifacts", "models")
    main()

I call this using CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py hydra_args=value.

I keep getting this error message:

...
  File "/homes/55/cornelius/anaconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/homes/55/cornelius/anaconda3/envs/llm/lib/python3.11/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)

The traceback shows that this is called from Trainer's _inner_training_loop.

What am I doing wrong? I thought that Accelerator(), the data collator, and Trainer would handle device placement automatically.

PS: the model also doesn't seem to react to load_in_8_bit=True. I don't know if this is related.
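
(For reference, my understanding is that 8-bit loading is requested on the model's from_pretrained call rather than through AutoConfig, so the flag may simply be ignored there. Below is a rough, untested sketch of what I mean, assuming bitsandbytes is installed; I haven't checked whether this is connected to the device error.)

    # hypothetical sketch: request 8-bit weights via a quantization config on the model,
    # not via AutoConfig (quantization is not a model-config attribute)
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model.architecture,
        quantization_config=bnb_config,
        cache_dir=model_dir,
    )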

Don't use the Accelerator yourself; Trainer uses it internally now, so there's no need! :smiley:
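
For example, the training part could look roughly like this (an untested sketch based on your snippet; the Accelerator import and the accelerator.prepare(...) call are simply removed, and mixed precision is requested through TrainingArguments, assuming cfg.training.precision is "fp16" or "bf16"):

    # model, tokenizer and data_collator stay as they are; Trainer moves everything to the right devices
    args = TrainingArguments(
        output_dir=model_dir,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        evaluation_strategy="steps",
        eval_steps=100,
        logging_steps=100,
        save_steps=100,
        push_to_hub=False,
        do_eval=True,
        report_to="none",
        fp16=cfg.training.precision == "fp16",  # mixed precision handled by Trainer
        bf16=cfg.training.precision == "bf16",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()

If you want proper distributed data parallel across the four GPUs, you can launch the same script with accelerate launch or torchrun instead of plain python; Trainer picks that up on its own.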


Thank you. It works now. Any idea why this error occurs?