Hugging Face Trainer class with accelerate

Hello! I have the following Python script for fine-tuning LLMs:

from accelerate import Accelerator
from huggingface_hub import login
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sagemaker.remote_function import remote
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import transformers

# Start training
def _tokenize_dataset(ds, tokenizer, accelerator, max_length):
    """Tokenize the ``text`` column of ``ds`` and drop every original column.

    The map runs on the main process first, so the ``datasets`` cache is built
    once and reused by the other ranks instead of being recomputed by each.
    Each example is truncated to at most ``max_length`` tokens.
    """
    with accelerator.main_process_first():
        return ds.map(
            lambda sample: tokenizer(
                sample["text"], truncation=True, max_length=max_length
            ),
            batched=True,
            remove_columns=list(ds.features),
        )


def _load_lora_model(model_name, accelerator, lora_r, lora_alpha, lora_dropout):
    """Load ``model_name`` quantized to 4-bit NF4 and wrap it with LoRA adapters.

    Each process loads its own copy of the quantized model onto the GPU that
    matches its local rank. That is the layout DDP needs: one full replica
    per process. Gradient checkpointing is always enabled here, independent
    of the ``gradient_checkpointing`` training argument.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map={"": accelerator.local_process_index},
        cache_dir="/tmp/.cache",
    )

    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # Put LoRA adapters on every linear layer reported by the helper.
    modules = find_all_linear_names(model)
    accelerator.print(f"Found {len(modules)} modules to quantize: {modules}")

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, config)
    print_trainable_parameters(model)
    return model


def train_fn(
    model_name,
    train_ds,
    test_ds=None,
    lora_r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    chunk_size=2048,
    gradient_checkpointing=False,
    merge_weights=False,
    token=None,
    tokenizer=None,
):
    """QLoRA fine-tune ``model_name`` on ``train_ds`` and save it to ``/opt/ml/model``.

    Launch the script with ``accelerate launch`` (or ``torchrun``) to get one
    process per GPU and DistributedDataParallel. Launched with plain
    ``python`` on a multi-GPU host, ``transformers.Trainer`` runs a single
    process with ``torch.nn.DataParallel``. DataParallel gathers each GPU's
    scalar loss, which produces the "Was asked to gather along dimension 0"
    warning.

    Args:
        model_name: Hub id or local path of the base causal language model.
        train_ds: Training dataset with a ``text`` column.
        test_ds: Optional evaluation dataset with the same layout.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout applied inside the LoRA adapters.
        per_device_train_batch_size: Training batch size per GPU.
        per_device_eval_batch_size: Evaluation batch size per GPU.
        gradient_accumulation_steps: Forward/backward passes per optimizer step.
        learning_rate: Optimizer learning rate.
        num_train_epochs: Number of passes over ``train_ds``.
        chunk_size: Maximum tokens kept per example; longer examples are truncated.
        gradient_checkpointing: Forwarded to ``TrainingArguments``. The model
            itself always has gradient checkpointing enabled.
        merge_weights: If true, merge the LoRA adapters into an fp16 copy of
            the base model before saving. Otherwise save only the adapters.
        token: Optional Hugging Face token used to log in (e.g. for gated models).
        tokenizer: Optional pre-configured tokenizer. If omitted, it is loaded
            from ``model_name``. A missing pad token is set to EOS in place.
    """
    accelerator = Accelerator()

    if token is not None:
        login(token=token)

    # Load after login so gated repositories are accessible.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    # DataCollatorForLanguageModeling pads every batch, which fails when the
    # tokenizer defines no pad token (as with Llama 3); reuse EOS in that case.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    lm_train_dataset = _tokenize_dataset(train_ds, tokenizer, accelerator, chunk_size)
    accelerator.print(f"Total number of train samples: {len(lm_train_dataset)}")

    lm_test_dataset = None
    if test_ds is not None:
        lm_test_dataset = _tokenize_dataset(test_ds, tokenizer, accelerator, chunk_size)
        accelerator.print(f"Total number of test samples: {len(lm_test_dataset)}")

    model = _load_lora_model(model_name, accelerator, lora_r, lora_alpha, lora_dropout)
    # The KV cache is not needed for training and conflicts with gradient checkpointing.
    model.config.use_cache = False

    # No manual `model.to(...)` or `accelerator.prepare(...)` here: device_map
    # already placed the model, and Trainer wraps it for distributed training
    # itself. Preparing it here too would wrap it in DDP twice.
    trainer = transformers.Trainer(
        model=model,
        train_dataset=lm_train_dataset,
        eval_dataset=lm_test_dataset,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            # Gradient checkpointing is always on for this model, and DDP's
            # unused-parameter search does not work with it. Set the flag
            # explicitly instead of relying on Trainer's default.
            ddp_find_unused_parameters=False,
            logging_steps=2,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            bf16=True,
            save_strategy="no",
            output_dir="outputs",
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    trainer.train()

    # Only the main process writes files, and every rank must finish training first.
    accelerator.wait_for_everyone()
    model_dir = "/opt/ml/model"

    if merge_weights:
        adapter_dir = "/tmp/model"
        if accelerator.is_main_process:
            # Save the LoRA adapters trained on top of the 4-bit base model.
            trainer.model.save_pretrained(adapter_dir, safe_serialization=False)

        # Release the quantized training model before loading a full-precision copy.
        del model
        del trainer
        torch.cuda.empty_cache()

        if accelerator.is_main_process:
            # Reload base model + adapters in fp16, fold LoRA into the weights, save.
            merged = AutoPeftModelForCausalLM.from_pretrained(
                adapter_dir,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                cache_dir="/tmp/.cache",
            )
            merged = merged.merge_and_unload()
            merged.save_pretrained(
                model_dir, safe_serialization=True, max_shard_size="2GB"
            )
    elif accelerator.is_main_process:
        # `trainer.model` is the unwrapped PEFT model, so only adapters are saved.
        trainer.model.save_pretrained(model_dir, safe_serialization=True)

    if accelerator.is_main_process:
        # Save the same tokenizer used for training, including its pad token.
        tokenizer.save_pretrained(model_dir)

if __name__ == "__main__":
    import os

    # `train_dataset` / `test_dataset` are presumably built earlier in the
    # script (not shown here). The Hugging Face token is read from the
    # environment instead of being hard-coded in source. When HF_TOKEN is
    # unset, `train_fn` skips the explicit login.
    train_fn(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        train_ds=train_dataset,
        test_ds=test_dataset,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        merge_weights=True,
        token=os.environ.get("HF_TOKEN"),
    )

where

type(train_dataset) -> datasets.arrow_dataset.Dataset
Dataset({
    features: ['text'],
    num_rows: 5400
})

I’m running the script above as an Amazon SageMaker Training Job on an ml.g5.12xlarge instance.

My questions:

  1. The selected instance type has 4 GPUs. How can I check whether training is actually distributed across all of them?
  2. I see the following warning in the logs:
warnings.warn(
/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '

How can I solve it?

Thank you!

A quick check is to call `torch.cuda.device_count()` and combine it with `torch.cuda.memory_stats(index)`, passing `trainer.args.distributed_state.process_index` as the index. Each process then reports memory statistics for its own GPU.

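For the warning, we’d need to know more about your data, and you should inspect your data to see what the outputs look like. Note, however, that this warning is raised from `torch/nn/parallel`, i.e. by `nn.DataParallel`, when it gathers the scalar loss from each GPU replica. It is harmless, but it shows that the script is running as a single process with DataParallel rather than with DistributedDataParallel. Launching it with `accelerate launch` or `torchrun` avoids this.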

FWIW, logging to Weights & Biases (wandb) is the easiest way to check both of these.