Device error while trying to train with custom linear layer on multi-gpu

I am trying to train an (facebook/opt) LLM for Causal Language Modeling with a custom implementation of SVD decomposition of the Linear Layer weights. I have 2 GPUSs BTW.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader

class SVDLinear(nn.Module):
    def __init__(self, linear_module, r_ratio=1.0):
        super().__init__()
        self.in_features = linear_module.in_features
        self.out_features = linear_module.out_features
        self.r_ratio = r_ratio

        # Perform standard SVD (unweighted)
        D, B = self.get_unweighted_svd(linear_module.weight.data.clone().t())

        # Register D and B as parameters
        self.D = nn.Parameter(D)
        self.B = nn.Parameter(B)
        self.bias = linear_module.bias  # Keep the original bias

    def get_unweighted_svd(self, W):
        with torch.no_grad():
            U, S, V = torch.svd(W)
            return self.truncate_svd(U, S, V)

    def truncate_svd(self, U, S, V):
        k = math.ceil(S.size(0) * self.r_ratio)  # Number of singular values to keep
        U_truncated = U[:, :k]
        S_truncated = S[:k]
        V_truncated = V[:, :k]

        D = U_truncated @ torch.diag(S_truncated)
        B = V_truncated.t()
        return D, B

    def forward(self, x):
        S = F.linear(x, self.D.t())  # First linear transformation using D
        return F.linear(S, self.B.t(), self.bias)  # Second linear transformation using B


# Step 1: Load Pretrained Model
def load_pretrained_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return model

# Step 2: Load the Dataset with Proper Tokenization and DataLoader
def load_wikitext_dataset(model_name, batch_size=2):
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Make sure the tokenizer uses EOS token for padding

    # Data collator for language modeling
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    def tokenize_function(examples):
        # Tokenize the text and prepare input_ids and labels for causal language modeling
        return tokenizer(
            examples['text'],
            padding="max_length",  # Ensure padding to a fixed length
            truncation=True,       # Truncate texts that are too long
            max_length=512,        # Set maximum sequence length
            return_tensors="pt"    # Return PyTorch tensors
        )

    # Tokenize the dataset and remove the 'text' column
    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset = tokenized_dataset.remove_columns(['text'])

    train_dataset = tokenized_dataset['train']
    validation_dataset = tokenized_dataset['validation']
    test_dataset = tokenized_dataset['test']

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, collate_fn=data_collator)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=data_collator)

    return train_loader, validation_loader, test_loader

# Step 3: Replace Linear Layers with Decomposed SVDLinear Layers
def replace_linear_with_svdlinear(model, r_ratio=1.0):
    modules_to_replace = []

    # Collect nn.Linear modules that need to be replaced with SVDLinear
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            modules_to_replace.append((name, module))

    # Replace the collected modules
    for name, module in modules_to_replace:
        svd_linear = SVDLinear(module, r_ratio=r_ratio)

        # Navigate through the parent module hierarchy
        *parent_path, last_name = name.split('.')
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        # Replace the nn.Linear with the new SVDLinear module
        setattr(parent_module, last_name, svd_linear)

    return model

# Step 4: Hugging Face Trainer Setup with Arguments Similar to Your Implementation
def train_model(model, train_loader, validation_loader):
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=1,
        learning_rate=5e-5,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1000,
        save_strategy="no",  # Disable saving checkpoints
        eval_steps=5000,
        max_steps=100,  # Limit the number of steps as per your use case
        report_to="none"  # Set to none if you are not using W&B
    )
    
    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_loader.dataset,
        eval_dataset=validation_loader.dataset,
        data_collator=train_loader.collate_fn
    )
    
    # Start Training
    trainer.train()

# Putting it All Together
if __name__ == "__main__":
    model_name = "facebook/opt-350m"  # You can change this to another model if needed
    
    # Load pretrained model
    model = load_pretrained_model(model_name)
    
    # Load the dataset with proper tokenization and DataLoader setup
    train_loader, validation_loader, test_loader = load_wikitext_dataset(model_name)


    # Decompose and replace linear layers with SVDLinear
    model = replace_linear_with_svdlinear(model, r_ratio=0.5)  # Replace with SVDLinear layers

    # Train model using Hugging Face Trainer
    train_model(model, train_loader, validation_loader)

This is the script that can generate the following error:

/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/training_args.py:1541: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
max_steps is given, it will override any value given in num_train_epochs
  0%|                                                                                                                                                                                                                                                  | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "error.py", line 152, in <module>
    train_model(model, train_loader, validation_loader)
  File "error.py", line 135, in train_model
    trainer.train()
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 2022, in train
    return inner_training_loop(
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 2358, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 3455, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 3502, in compute_loss
    outputs = model(**inputs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py", line 1011, in forward
    outputs = self.model.decoder(
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py", line 798, in forward
    hidden_states = self.project_out(hidden_states)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "error.py", line 40, in forward
    S = F.linear(x, self.D.t())  # First linear transformation using D
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
  0%|                                                                                                                                                                                                                                                  | 0/100 [00:00<?, ?it/s]

I am not sure why this is happening. Also I noticed that if I switch to a smaller checkpoint like facebook/opt-125m , I don’t get this error.

Any help would be greatly appreciated :pray: .

1 Like