I am trying to train an LLM (facebook/opt) for causal language modeling with a custom implementation that replaces each linear layer's weights with an SVD decomposition. For context, I am running on 2 GPUs.
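The idea is to factor each (transposed) weight matrix as W ≈ D @ B, with D = U_k * diag(S_k) and B = V_k^T, so the original y = x @ W + bias becomes two smaller matmuls. As a quick standalone sanity check of the factorization (a sketch with made-up layer sizes, run on CPU; SVDLinear is the class defined in the script below), at r_ratio=1.0 the decomposed layer should match the original nn.Linear up to numerical error:

import torch
import torch.nn as nn

linear = nn.Linear(16, 8)                    # hypothetical small layer
svd_linear = SVDLinear(linear, r_ratio=1.0)  # SVDLinear from the script below

x = torch.randn(4, 16)
# With no truncation the factorization is exact, so the outputs should agree
print(torch.allclose(linear(x), svd_linear(x), atol=1e-5))  # expect True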
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader
class SVDLinear(nn.Module):
    def __init__(self, linear_module, r_ratio=1.0):
        super().__init__()
        self.in_features = linear_module.in_features
        self.out_features = linear_module.out_features
        self.r_ratio = r_ratio
        # Perform standard SVD (unweighted)
        D, B = self.get_unweighted_svd(linear_module.weight.data.clone().t())
        # Register D and B as parameters
        self.D = nn.Parameter(D)
        self.B = nn.Parameter(B)
        self.bias = linear_module.bias  # Keep the original bias

    def get_unweighted_svd(self, W):
        with torch.no_grad():
            U, S, V = torch.svd(W)
        return self.truncate_svd(U, S, V)

    def truncate_svd(self, U, S, V):
        k = math.ceil(S.size(0) * self.r_ratio)  # Number of singular values to keep
        U_truncated = U[:, :k]
        S_truncated = S[:k]
        V_truncated = V[:, :k]
        D = U_truncated @ torch.diag(S_truncated)
        B = V_truncated.t()
        return D, B

    def forward(self, x):
        S = F.linear(x, self.D.t())  # First linear transformation using D
        return F.linear(S, self.B.t(), self.bias)  # Second linear transformation using B
# Step 1: Load Pretrained Model
def load_pretrained_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return model
# Step 2: Load the Dataset with Proper Tokenization and DataLoader
def load_wikitext_dataset(model_name, batch_size=2):
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Make sure the tokenizer uses the EOS token for padding

    # Data collator for language modeling
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    def tokenize_function(examples):
        # Tokenize the text and prepare input_ids and labels for causal language modeling
        return tokenizer(
            examples['text'],
            padding="max_length",  # Ensure padding to a fixed length
            truncation=True,       # Truncate texts that are too long
            max_length=512,        # Set maximum sequence length
            return_tensors="pt"    # Return PyTorch tensors
        )

    # Tokenize the dataset and remove the 'text' column
    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset = tokenized_dataset.remove_columns(['text'])

    train_dataset = tokenized_dataset['train']
    validation_dataset = tokenized_dataset['validation']
    test_dataset = tokenized_dataset['test']

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, collate_fn=data_collator)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=data_collator)

    return train_loader, validation_loader, test_loader
# Step 3: Replace Linear Layers with Decomposed SVDLinear Layers
def replace_linear_with_svdlinear(model, r_ratio=1.0):
    modules_to_replace = []
    # Collect nn.Linear modules that need to be replaced with SVDLinear
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            modules_to_replace.append((name, module))
    # Replace the collected modules
    for name, module in modules_to_replace:
        svd_linear = SVDLinear(module, r_ratio=r_ratio)
        # Navigate through the parent module hierarchy
        *parent_path, last_name = name.split('.')
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)
        # Replace the nn.Linear with the new SVDLinear module
        setattr(parent_module, last_name, svd_linear)
    return model
# Step 4: Hugging Face Trainer Setup
def train_model(model, train_loader, validation_loader):
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=1,
        learning_rate=5e-5,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1000,
        save_strategy="no",  # Disable saving checkpoints
        eval_steps=5000,
        max_steps=100,       # Limit the number of training steps
        report_to="none"     # Disable reporting (e.g. W&B)
    )
    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_loader.dataset,
        eval_dataset=validation_loader.dataset,
        data_collator=train_loader.collate_fn
    )
    # Start Training
    trainer.train()
# Putting it All Together
if __name__ == "__main__":
    model_name = "facebook/opt-350m"  # You can change this to another model if needed

    # Load pretrained model
    model = load_pretrained_model(model_name)

    # Load the dataset with proper tokenization and DataLoader setup
    train_loader, validation_loader, test_loader = load_wikitext_dataset(model_name)

    # Decompose and replace linear layers with SVDLinear
    model = replace_linear_with_svdlinear(model, r_ratio=0.5)

    # Train model using Hugging Face Trainer
    train_model(model, train_loader, validation_loader)
Running the script above produces the following error:
/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/training_args.py:1541: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
warnings.warn(
max_steps is given, it will override any value given in num_train_epochs
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "error.py", line 152, in <module>
train_model(model, train_loader, validation_loader)
File "error.py", line 135, in train_model
trainer.train()
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 2022, in train
return inner_training_loop(
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 2358, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 3455, in training_step
loss = self.compute_loss(model, inputs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/trainer.py", line 3502, in compute_loss
outputs = model(**inputs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py", line 1011, in forward
outputs = self.model.decoder(
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py", line 798, in forward
hidden_states = self.project_out(hidden_states)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/mayank/miniconda3/envs/llm_train/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "error.py", line 40, in forward
S = F.linear(x, self.D.t()) # First linear transformation using D
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
0%| | 0/100 [00:00<?, ?it/s]
I am not sure why this is happening. I also noticed that if I switch to a smaller checkpoint like facebook/opt-125m, I don't get this error.
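In case it is relevant, a quick way to see how the model gets placed across the two GPUs would be something like this (a minimal diagnostic sketch; I am assuming hf_device_map is populated when device_map="auto" is used, and that opt-350m gets split across cuda:0 and cuda:1 while opt-125m fits on a single device):

# Diagnostic sketch: inspect how device_map="auto" placed the model,
# and where the parameters of the replacement SVDLinear modules live.
model = load_pretrained_model("facebook/opt-350m")
print(model.hf_device_map)  # mapping of submodules to devices

model = replace_linear_with_svdlinear(model, r_ratio=0.5)
for name, module in model.named_modules():
    if isinstance(module, SVDLinear):
        print(name, module.D.device, module.B.device)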
Any help would be greatly appreciated.