Fine-tuning an LLM (e.g. Mistral-7B) on multiple CPUs with (Q)LoRA

Hi!
I am trying to fine-tune Mistral-7B-v0.1 on multiple CPUs with LoRA. I was thinking of using the accelerate, peft and transformers libraries for that.
I wanted to use Accelerate in my own training loop (for the backward pass on the loss), but I am having trouble passing my batches to Mistral.

Does anyone have a solution for this? I also tried the Intel library, which was really buggy for me, and did not succeed. So overall my goal is:
Fine-tuning Mistral-7B (or any other LLM) with (Q)LoRA on multiple CPUs.

BASIC-CODE:
import torch
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)
from trl import SFTTrainer
from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of torch.optim.AdamW
from transformers import get_linear_schedule_with_warmup

model_name = "mistralai/Mistral-7B-v0.1"
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True, use_fast=True)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    model_name, load_in_4bit=True, use_llm_runtime=False, torch_dtype=torch.float32, low_cpu_mem_usage=False
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# Configure the pad token in the model
model.config.pad_token_id = tokenizer.pad_token_id

model.gradient_checkpointing_enable()
peft_config = LoraConfig(
    r=8,
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, peft_config)
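To get a feel for how small the trainable part is: with `r=8` and no `target_modules`, PEFT picks default target modules for the model type, which for Mistral are, to my knowledge, `q_proj` and `v_proj`. LoRA adds two matrices to each targeted linear layer of shape `in × out`, A (`r × in`) and B (`out × r`), i.e. `r * (in + out)` parameters. The arithmetic below is a sketch assuming the public Mistral-7B-v0.1 config (hidden size 4096, 32 layers, 32 attention heads, 8 key/value heads); `model.print_trainable_parameters()` on the PEFT model gives the authoritative count.

```python
# Sketch: LoRA parameters added to Mistral-7B-v0.1 with r=8.
# Assumed values from the public Mistral-7B-v0.1 config.json:
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 128


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in) and B (out x r) next to a frozen linear layer."""
    return r * (in_features + out_features)


def mistral_lora_params(r: int = 8) -> int:
    """Assumes PEFT's default targets for Mistral: q_proj and v_proj in every layer."""
    q_proj = lora_params(HIDDEN_SIZE, HIDDEN_SIZE, r)              # 4096 -> 4096
    v_proj = lora_params(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, r)  # 4096 -> 1024 (grouped-query attention)
    return NUM_LAYERS * (q_proj + v_proj)


print(f"{mistral_lora_params(8):,}")  # 3,407,872
```

That is roughly 0.05% of the ~7.2B frozen weights, which is why only the adapter gradients and optimizer state have to fit in CPU memory.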

def format_ultrachat(ds):
    text = []
    for row in ds:
        if len(row["messages"]) >= 4:  # use at most the first two turns
            text.append(
                "### Human: " + row["messages"][0]["content"]
                + "### Assistant: " + row["messages"][1]["content"]
                + "### Human: " + row["messages"][2]["content"]
                + "### Assistant: " + row["messages"][3]["content"]
            )
        else:  # not all dialogues have more than one turn
            text.append(
                "### Human: " + row["messages"][0]["content"]
                + "### Assistant: " + row["messages"][1]["content"]
            )
    ds = ds.add_column(name="text", column=text)
    return ds

dataset_train_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
dataset_test_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft[:5%]")

dataset_test_sft = format_ultrachat(dataset_test_sft)
dataset_train_sft = format_ultrachat(dataset_train_sft)
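`format_ultrachat` keeps at most the first two turns of each dialogue and builds one Python list over the whole split. A sketch of a row-level formatter that keeps every turn, assuming each message is a dict with `role` (`"user"` or `"assistant"`) and `content`, as in UltraChat 200k, and reusing the same `### Human:` / `### Assistant:` template:

```python
from typing import Dict, List

# Prompt prefixes matching the template used in format_ultrachat.
ROLE_PREFIX = {"user": "### Human: ", "assistant": "### Assistant: "}


def format_messages(messages: List[Dict[str, str]]) -> str:
    """Concatenate every turn of one dialogue into a single training string."""
    return "".join(ROLE_PREFIX[m["role"]] + m["content"] for m in messages)


example = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "Fine."},
]
print(format_messages(example))
# ### Human: Hi### Assistant: Hello!### Human: How are you?### Assistant: Fine.
```

With the `datasets` library it could be applied row by row, e.g. `dataset.map(lambda row: {"text": format_messages(row["messages"])})`, instead of collecting all strings and calling `add_column`.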

#### THIS SHOULD BE REMOVED AND REPLACED WITH MY OWN LOOP ####
training_arguments = TrainingArguments(
    output_dir="./results_sft_cpu/",
    # evaluation_strategy="steps",
    # do_eval=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=4,
    log_level="debug",
    save_steps=20,
    logging_steps=10,
    learning_rate=1e-4,
    # eval_steps=10,
    max_steps=200,
    warmup_steps=20,
    lr_scheduler_type="linear",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_train_sft,
    # eval_dataset=dataset_test_sft,
    dataset_text_field="text",
    max_seq_length=256,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()
####### END ###########
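For reference when rebuilding this in a custom loop: with `warmup_steps=20`, `max_steps=200` and `lr_scheduler_type="linear"`, the learning rate ramps linearly from 0 to `1e-4` over the first 20 optimizer steps, then decays linearly to 0 at step 200. A plain-Python sketch of that schedule, mirroring the multiplier used by `transformers.get_linear_schedule_with_warmup`; "step" here means an optimizer update, i.e. one step per `gradient_accumulation_steps` micro-batches.

```python
def linear_warmup_lr(step: int, base_lr: float = 1e-4,
                     warmup_steps: int = 20, total_steps: int = 200) -> float:
    """Learning rate after `step` optimizer updates: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


for step in (0, 10, 20, 110, 200):
    print(step, linear_warmup_lr(step))
```

The same reasoning applies to `num_training_steps` in a hand-written loop: it should count optimizer updates, not dataset rows.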


Goal:

[…]

[prepare all the data and the model with accelerate]
###### CHANGED CODE ######
num_train_epochs = 2
learning_rate = 1e-4
nbr_warmup = 0
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_training_steps = len(dataset_train_sft) * num_train_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=nbr_warmup, num_training_steps=num_training_steps)

model.train()
for epoch in range(num_train_epochs):
    for step, batch in enumerate(dataset_train_sft):
        output = model(**batch)
        [backward pass on the loss etc. with accelerate]


The error I currently get is:
MistralForCausalLM.forward() got an unexpected keyword argument 'prompt'

This is the content of each batch (the keys of the dict) that I pass to Mistral:

model.train()
for epoch in range(num_train_epochs):
    for step, batch in enumerate(dataset_train_sft):
        for x in batch:
            print(x)
        break


> prompt
> prompt_id
> messages
> text
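The `TypeError` is plain Python keyword-argument behaviour: `model(**batch)` unpacks every key of the row as a keyword argument, and in the version used here `forward()` has no parameter named `prompt`. A minimal illustration with a hypothetical stand-in `forward` (the real one has many more optional arguments) and a row shaped like the keys printed above; `toy_tokenize` is also hypothetical and only marks where the real `tokenizer` call belongs.

```python
from typing import Dict, List, Optional


def forward(input_ids: Optional[List[int]] = None,
            attention_mask: Optional[List[int]] = None,
            labels: Optional[List[int]] = None) -> int:
    """Hypothetical stand-in for MistralForCausalLM.forward: accepts only token-level inputs."""
    return len(input_ids or [])


# One raw UltraChat row after format_ultrachat: strings and message dicts, no token ids.
row = {"prompt": "Hi", "prompt_id": "abc", "messages": [], "text": "### Human: Hi"}

try:
    forward(**row)  # every dict key becomes a keyword argument
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'prompt'


def toy_tokenize(text: str) -> Dict[str, List[int]]:
    """Hypothetical toy tokenizer (one id per character); the real script would call `tokenizer`."""
    ids = [ord(c) for c in text]
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": list(ids)}


print(forward(**toy_tokenize(row["text"])))  # number of token ids in the row's text
```

Iterating over a dict yields its keys, which is why the debugging loop prints `prompt`, `prompt_id`, `messages` and `text`.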


Is this possible at the moment?
With the Intel library I was able to train on a single CPU, but I want to use multiple CPUs.

Help would be much appreciated.