I'm trying to fine-tune Llama-2-13b-chat for machine translation with the Lightning Trainer and DeepSpeed. The run reported here is on 4x RTX 3090 with an 8-bit quantized version of the model and LoRA, and I get the error below. This is the code used for training:
trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
    # strategy="ddp",
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
    ),
    max_steps=10,
    # precision="16-mixed",
    enable_checkpointing=True,
    accumulate_grad_batches=args.accumulate_grad_batches,
    log_every_n_steps=20,
    val_check_interval=30,
    limit_val_batches=200,
    # default_root_dir="checkpoints",
    callbacks=[checkpoint_callback],
    gradient_clip_val=args.gradient_clip_val,
    logger=mlflow_logger,
)
print(args.quantized)
print(quantization_config)
model = MTModel(
    model_name=args.model_name,
    pad_token_id=tokenizer.pad_token_id,
    inference=False,
    learning_rate=args.learning_rate,
    weight_decay=args.weight_decay,
    betas=args.betas,
    quantization_config=quantization_config,
    peft=args.peft,
    load_in_8bit=False,
)
trainer.fit(model, train_data_loader, test_data_loader)
This is MTModel:
import torch
import lightning as L
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.adam import DeepSpeedCPUAdam

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16,
    bnb_8bit_quant_type="nf8",
    bnb_8bit_use_double_quant=True,
)

class MTModel(L.LightningModule):
    def __init__(self, model_name: str, pad_token_id: int, inference: bool,
                 learning_rate: float = 1e-4, weight_decay: float = 0.0,
                 betas: tuple = (0.9, 0.95), quantization_config: BitsAndBytesConfig = None,
                 load_in_8bit: bool = False, peft: bool = False):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",
            quantization_config=quantization_config,
            load_in_8bit=load_in_8bit,
        )
        if peft:
            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=inference,
                                     r=4, lora_alpha=4, lora_dropout=0.01)
            self.model = get_peft_model(self.model, peft_config)
    def training_step(self, batch, batch_idx):
        get_accelerator().empty_cache()
        input_ids, target_start_idx = batch
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, target_start_idx = batch
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("val_loss", loss, sync_dist=True)
        self.log("val_ppl", torch.exp(loss), sync_dist=True)
        return loss

    def generate(self, batch, **kwargs):
        return self.model.generate(batch, **kwargs)

    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.learning_rate,
                                weight_decay=self.hparams.weight_decay, betas=self.hparams.betas)
And here is the error:
/home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/functional.py:2235 in embedding

  2232         # torch.embedding_renorm_
  2233         # remove once script supports set_grad_enabled
  2234         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
❱ 2235     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
  2236
  2237
  2238 def embedding_bag(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
I found one similar issue, but I get the error even without load_in_8bit=True, and in any case the model dispatch should be handled by DeepSpeed. Has anyone encountered similar errors?
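In case it is relevant: the only explicit device placement on my side is the device_map="auto" passed to from_pretrained, which, as far as I understand, lets accelerate spread the layers over the four GPUs on its own, independently of DeepSpeed. Below is a minimal sketch of the loading variant I am planning to try next, assuming DeepSpeed stage 3 should own the parameter placement; this is just my guess at the cause, not a confirmed fix.

# Hypothetical change inside MTModel.__init__: drop device_map="auto" so that
# each rank loads the model itself and the DeepSpeed strategy decides where the
# parameters live, instead of accelerate scattering layers over cuda:0..cuda:3.
self.model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    load_in_8bit=load_in_8bit,
)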