Issue loading an LLM for text classification in 8-bit

When I try to load `EleutherAI/gpt-neox-20b` for a sequence classification task using `AutoModelForSequenceClassification`, I keep getting the error ``ValueError: weight is on the meta device, we need a `value` to put in on 0.``

I assume it has to do with the linear-layer weights that are added for the classification task. Does anyone know how to fix this issue?

Full error:

Some weights of the model checkpoint at EleutherAI/gpt-neox-20b were not used when initializing GPTNeoXForSequenceClassification: ['embed_out.weight']
- This IS expected if you are initializing GPTNeoXForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTNeoXForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
ValueError                                Traceback (most recent call last)
Cell In[8], line 1
----> 1 model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=6, problem_type="multi_label_classification", 
      2                                                          load_in_8bit=True, device_map="auto", id2label=id2label, label2id=label2id)

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/transformers/models/auto/, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    469 elif type(config) in cls._model_mapping.keys():
    470     model_class = _get_model_class(config, cls._model_mapping)
--> 471     return model_class.from_pretrained(
    472         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    473     )
    474 raise ValueError(
    475     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    476     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    477 )

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/transformers/, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2844 # Dispatch model with hooks on all devices if necessary
   2845 if device_map is not None:
-> 2846     dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
   2848 if output_loading_info:
   2849     if loading_info is None:

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/accelerate/, in dispatch_model(model, device_map, main_device, state_dict, offload_dir, offload_index, offload_buffers, preload_module_classes)
    367     weights_map = None
    369 tied_params = find_tied_parameters(model)
--> 370 attach_align_device_hook_on_blocks(
    371     model,
    372     execution_device=execution_device,
    373     offload=offload,
    374     offload_buffers=offload_buffers,
    375     weights_map=weights_map,
    376     preload_module_classes=preload_module_classes,
    377 )
    378 # Attaching the hook may break tied weights, so we retie them
    379 retie_parameters(model, tied_params)

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/accelerate/, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, preload_module_classes)
    471 if module_name in execution_device and module_name in offload and not offload[module_name]:
    472     hook = AlignDevicesHook(
    473         execution_device=execution_device[module_name],
    474         offload_buffers=offload_buffers,
    475         io_same_device=(module_name == ""),
    476         place_submodules=True,
    477     )
--> 478     add_hook_to_module(module, hook)
    479     attach_execution_device_hook(module, execution_device[module_name])
    480 elif module_name in execution_device and module_name in offload:

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/accelerate/, in add_hook_to_module(module, hook, append)
    152     old_forward = module.forward
    153     module._old_forward = old_forward
--> 155 module = hook.init_hook(module)
    156 module._hf_hook = hook
    158 @functools.wraps(old_forward)
    159 def new_forward(*args, **kwargs):

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/accelerate/, in AlignDevicesHook.init_hook(self, module)
    249 if not self.offload and self.execution_device is not None:
    250     for name, _ in named_module_tensors(module, recurse=self.place_submodules):
--> 251         set_module_tensor_to_device(module, name, self.execution_device)
    252 elif self.offload:
    253     self.original_devices = {
    254         name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
    255     }

File ~/research/conda/envs/mpeft/lib/python3.9/site-packages/accelerate/utils/, in set_module_tensor_to_device(module, tensor_name, device, value, dtype)
    133 old_value = getattr(module, tensor_name)
    135 if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
--> 136     raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
    138 if value is not None:
    139     if dtype is None:
    140         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model

ValueError: weight is on the meta device, we need a `value` to put in on 0.