How to properly wrap a model for training with accelerate?

Hello! Here is a minimal reproducible example of my model wrapper and training setup:

import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
import torch
import transformers


base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/opt-1.3b", truncation_side="right", padding_side="right"
)

class BypassCausalLM(nn.Module):
    def __init__(self, base_model: AutoModelForCausalLM):
        super().__init__()

        # Use the provided pre-initialized model
        self.model = base_model

    def forward(self, input_ids, attention_mask=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, **kwargs)
    
    def resize_token_embeddings(self, new_num_tokens):
        return self.model.resize_token_embeddings(new_num_tokens)

model = BypassCausalLM(base_model)

from datasets import load_dataset

dataset = load_dataset("lambada")
data = dataset.map(lambda sample: tokenizer(sample["text"], padding=False, truncation=True, max_length=512), num_proc=4)

trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    eval_dataset=data['validation'],
    args=transformers.TrainingArguments(
        do_train=True,
        do_eval=True,
        auto_find_batch_size=True,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=50,
        num_train_epochs=1,
        learning_rate=1e-6,
        logging_steps=1,
        output_dir='outputs',
        eval_steps=200,
        evaluation_strategy="steps",
        label_names=["labels"],
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
torch.cuda.empty_cache()
trainer.train()
print(trainer.evaluate())

Running it produces the following error:

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 27
      4 trainer = transformers.Trainer(
      5     model=model,
      6     train_dataset=data['train'],
   (...)
     24     data_collator=transformers.DataCollatorWithPadding(tokenizer, padding=True)
     25 )
     26 torch.cuda.empty_cache()
---> 27 trainer.train()
     28 print(trainer.evaluate())

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1585, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1583         hf_hub_utils.enable_progress_bars()
   1584 else:
-> 1585     return inner_training_loop(
   1586         args=args,
   1587         resume_from_checkpoint=resume_from_checkpoint,
   1588         trial=trial,
   1589         ignore_keys_for_eval=ignore_keys_for_eval,
   1590     )

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/memory.py:136, in find_executable_batch_size.<locals>.decorator(*args, **kwargs)
    134     raise RuntimeError("No executable batch size found, reached zero.")
    135 try:
--> 136     return function(batch_size, *args, **kwargs)
    137 except Exception as e:
    138     if should_reduce_batch_size(e):

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1885, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1882     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1884 with self.accelerator.accumulate(model):
-> 1885     tr_loss_step = self.training_step(model, inputs)
   1887 if (
   1888     args.logging_nan_inf_filter
   1889     and not is_torch_tpu_available()
   1890     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1891 ):
   1892     # if loss is nan or inf simply add the average of previous logged losses
   1893     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2750, in Trainer.training_step(self, model, inputs)
   2747     return loss_mb.reduce_mean().detach().to(self.args.device)
   2749 with self.compute_loss_context_manager():
-> 2750     loss = self.compute_loss(model, inputs)
   2752 if self.args.n_gpu > 1:
   2753     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2775, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2773 else:
   2774     labels = None
-> 2775 outputs = model(**inputs)
   2776 # Save past state if it exists
   2777 # TODO: this needs to be fixed and made cleaner later.
   2778 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
    169     return self.module(*inputs[0], **kwargs[0])
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
    172 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
     87     output = results[i]
     88     if isinstance(output, ExceptionWrapper):
---> 89         output.reraise()
     90     outputs.append(output)
     91 return outputs

File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1997/2704660530.py", line 20, in forward
    return self.model(input_ids, attention_mask=attention_mask, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 944, in forward
    outputs = self.model.decoder(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 710, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 327, in forward
    hidden_states = self.self_attn_layer_norm(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2515, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)

The error occurs when trying to train a model initialized with device_map="auto" via transformers.Trainer.
What is the proper way to wrap a model so that it can be trained with accelerate support?
NOTE: I need this wrapper to be present; the code above is just a minimal example.

I have managed to solve it myself:

import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import PreTrainedModel, PretrainedConfig

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

class BypassCausalLM(nn.Module):
    def __init__(self, base_model: AutoModelForCausalLM):
        super().__init__()

        # Use the provided pre-initialized model
        self.model = base_model

    def forward(self, input_ids, attention_mask=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, **kwargs)
    
    def resize_token_embeddings(self, new_num_tokens):
        return self.model.resize_token_embeddings(new_num_tokens)

model = BypassCausalLM(base_model)
# Balance the available accelerator memory across devices, never splitting
# the model's no-split blocks (e.g. the OPT decoder layers)
max_memory = get_balanced_memory(
    model,
    max_memory=None,
    no_split_module_classes=base_model._no_split_modules,
    low_zero=False,
)

# Build a module-to-device map for the wrapper itself, not just the base model
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=base_model._no_split_modules,
)

# Place the weights and attach accelerate's hooks; dispatch_model also records
# the map on model.hf_device_map, so the Trainer treats the wrapper as model-parallel
model = dispatch_model(model, device_map=device_map)
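
For completeness, here is how the dispatched wrapper then plugs back into the Trainer. This is only a sketch under two assumptions that the snippet above leaves implicit: the base model is loaded without device_map="auto" (so that dispatch_model on the wrapper is the only thing placing weights on GPUs), and the LAMBADA preprocessing from the first snippet is repeated so the example is self-contained. Since the dispatched wrapper carries an hf_device_map, recent Trainer versions should treat it as already model-parallel and skip the nn.DataParallel wrapping that caused the original device-mismatch error.

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

# Load the base model on CPU first; device placement happens via dispatch_model below
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/opt-1.3b", truncation_side="right", padding_side="right"
)

model = BypassCausalLM(base_model)  # the wrapper class defined above

max_memory = get_balanced_memory(
    model, no_split_module_classes=base_model._no_split_modules
)
device_map = infer_auto_device_map(
    model, max_memory=max_memory, no_split_module_classes=base_model._no_split_modules
)
model = dispatch_model(model, device_map=device_map)

dataset = load_dataset("lambada")
data = dataset.map(
    lambda sample: tokenizer(sample["text"], truncation=True, max_length=512),
    num_proc=4,
)

trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=1e-6,
        output_dir="outputs",
        evaluation_strategy="steps",
        eval_steps=200,
        label_names=["labels"],
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

The key difference from the failing setup is that dispatch_model is applied to the wrapper itself, so the object handed to the Trainer is the one that carries the device map.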