Model not using all attention layers during inference with device_map="auto"

Here in my code I load Qwen/Qwen3-4B as AutoModelForCausalLM with device_map="auto". When I register a hook on all the modules via model.named_modules(), hooks get set on every layer (I verified this by printing the names at registration time). But when I print the name inside the activation hook, only the attn/mlp layers from 0-15 show up. When I change device_map="auto" to "cuda:0", the activation hook prints all the layer names from 0-35. Any explanation for this would help. The machine I am using has 2 A100 GPUs. Please consider me a beginner, and ask me if you need any more details on the issue I am facing.


device_map="auto" setting in the Accelerate library has quirks.


Explanation: device_map="auto" shards whole layers across your two GPUs. Hooks still run on every layer, but your hook then mixes tensors from different devices (e.g., the layer output on cuda:1 with self.unembed / self.answers on cuda:0). PyTorch requires all tensors in an op to be on the same device, so the hook raises a device-mismatch error inside the callback. Your except: pass swallows that error, so you only see names from layers that live on the same GPU as your extra tensors. Forcing device_map="cuda:0" keeps everything on one GPU, so all layers print. Qwen3-4B has 36 transformer blocks, so a 0–15 vs 16–35 split is typical on 2×A100. (Hugging Face) (Hugging Face) (PyTorch Forums) (Hugging Face)
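
To see the failure mode concretely, here is a minimal sketch (it assumes two visible GPUs) of the same error a device-mixing hook hits, and of how a bare except: pass hides it:

import torch

a = torch.ones(4, device="cuda:0")  # tensor on the first GPU
b = torch.ones(4, device="cuda:1")  # tensor on the second GPU

try:
    a @ b  # RuntimeError: Expected all tensors to be on the same device ...
except Exception as e:
    print(e)  # printing the exception makes the problem visible
    # a bare `pass` here would hide it completely, as in the original hook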

Beginner-safe fixes:

  1. Inspect the split.
print(model.hf_device_map)  # shows layer -> device
# docs: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference

(Hugging Face)
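
If you want the split at a glance, here is a small sketch (assuming the model was loaded with device_map="auto", so hf_device_map is populated) that groups modules by device:

from collections import defaultdict

per_device = defaultdict(list)
for module_name, dev in model.hf_device_map.items():
    per_device[dev].append(module_name)

for dev, names in sorted(per_device.items(), key=lambda kv: str(kv[0])):
    print(dev, f"-> {len(names)} modules, e.g.", names[:3])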

  2. Make the hook device-aware. Move everything you use inside the hook to the layer’s device (output.device). Also stop swallowing errors so you can see them.
def _activation_hook(self, module, _input, output, name):
    print(name)  # print first; always visible even if later ops fail
    try:
        import torch  # better at module level; kept local here to match the original sketch
        with torch.no_grad():
            # decoder blocks return a tuple; attn/mlp submodules return a plain tensor
            hidden = output if torch.is_tensor(output) else output[0]
            dev = hidden.device  # the device this layer’s output lives on

            # cache one unembed copy per device to avoid repeated transfers
            if not hasattr(self, "unembed_cache"):
                self.unembed_cache = {}
            if dev not in self.unembed_cache:
                # ref: device maps and dispatch
                # https://huggingface.co/docs/accelerate/en/usage_guides/big_modeling
                self.unembed_cache[dev] = self.unembed.to(dev, non_blocking=True)
            unembed = self.unembed_cache[dev]

            # move the other tensors the hook uses onto the same device
            answers = self.answers.to(dev, non_blocking=True)
            last_hidden = hidden[:, -1, :]  # already on `dev`

            # do your per-layer computation on a single device
            # PyTorch requires same-device tensors:
            # https://discuss.pytorch.org/t/model-parallelize-runtimeerror-expected-all-tensors-to-be-on-the-same-device-but-found-at-least-two-devices-cuda-0-and-cuda-1/172673
            logits = torch.einsum("bi,bi->b",
                                  unembed[answers.squeeze(-1), :],
                                  last_hidden)

            self.logit_dict[name].extend(logits.detach().cpu().tolist())
    except Exception as e:
        print(f"[hook {name}] {e}")  # show real errors instead of silently passing
    return output

(Hugging Face)
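
To register that hook and bind each module’s name, one option is functools.partial. This is a sketch that assumes the hook lives on the same object as unembed/answers; the .self_attn / .mlp suffix filter is an assumption about the model’s module naming (check the output of model.named_modules() first):

import functools

for name, module in model.named_modules():
    # hook only the attention/MLP submodules of each decoder block
    if name.endswith((".self_attn", ".mlp")):
        module.register_forward_hook(
            functools.partial(self._activation_hook, name=name)
        )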

  3. If you want zero device copies, only hook layers that already live on the same device as your reference tensor (here unembed). You’ll see fewer layers but no transfers.
# keep hooks only where module parameters share the unembed’s device
unembed_dev = (next(self.unembed.parameters()).device
               if hasattr(self.unembed, "parameters") else self.unembed.device)

for name, module in model.named_modules():
    try:
        mod_dev = next(module.parameters()).device
    except StopIteration:
        continue
    if mod_dev == unembed_dev:
        module.register_forward_hook(lambda m,i,o, n=name: self._activation_hook(m,i,o,n))
# hook API:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

(PyTorch Documentation)

  4. Lightweight “sanity” hook that never mixes devices. Good for first checks.
import torch

def meta_hook(module, _in, out, name):
    # minimal, no tensor math; just log device and shape
    dev = out.device if torch.is_tensor(out) else out[0].device
    shape = tuple(out.shape) if torch.is_tensor(out) else tuple(out[0].shape)
    print(f"{name} -> device={dev}, shape={shape}")

for name, module in model.named_modules():
    module.register_forward_hook(lambda m,i,o, n=name: meta_hook(m,i,o,n))
# forward hooks overview:
# https://discuss.pytorch.org/t/how-to-register-forward-hooks-for-each-module/43347

(PyTorch Forums)

Why this explains your printout:

  • Accelerate spreads layers across GPUs and sometimes CPU/disk; you must operate per-device inside the hook. (Hugging Face)
  • model.hf_device_map shows the exact layer→device mapping; you’ll likely see early blocks on cuda:0 and later ones on cuda:1. (Hugging Face)
  • Qwen3-4B has 36 transformer layers, so seeing only 0–15 when your extra tensors sit on cuda:0 fits a ~half split. (Hugging Face)
  • PyTorch ops require single-device inputs; silent except: pass hid the error. Remove it or print the exception. (PyTorch Forums)

Tip: If you ever need a strict single-GPU run to avoid all this, load with device_map="cuda:0". If you want multi-GPU sharding, keep device_map="auto" and write hooks that are purely local to the current layer’s device.
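
For reference, a minimal loading sketch for both options (the dtype is just a common choice, not something from your post):

import torch
from transformers import AutoModelForCausalLM

model_id = "Qwen/Qwen3-4B"

# strict single-GPU: every layer (and every hook) stays on cuda:0
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda:0"
)

# multi-GPU sharding: layers are split across GPUs; hooks must stay device-local
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, torch_dtype=torch.bfloat16, device_map="auto"
# )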

Further reading (short, high-quality):

  • Accelerate big-model inference: device maps, sharding strategy, and hf_device_map. Good conceptual background and practical examples. (Hugging Face)
  • PyTorch forward hooks API: exact semantics of register_forward_hook. Useful for correct hook signatures and removal. (PyTorch Documentation)
  • Same-device requirement: typical error and fixes in multi-GPU model-parallel setups. Useful when a hook mixes devices. (PyTorch Forums)
  • Qwen3-4B model card: quick spec confirmation that you’re dealing with 36 layers. (Hugging Face)
  • Real-world device_map="auto" on 2×A100: confirms multi-GPU placement happens as expected. (GitHub)

This will make your hooks reliable with device_map="auto" and keep the code easy to reason about.

Thanks a lot, sir. Deeply grateful for your reply. I added the try/except because I set hooks on all the layers, and some layers’ outputs were expected to throw shape-mismatch errors; that’s why I missed the device-mismatch error. The blanket try/except was careless of me. Thanks again for your time.
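
For anyone who hits the same thing: a narrower guard like this sketch (the helper name and the string filter are only illustrative) keeps the expected shape errors quiet while still letting anything unexpected, like a device mismatch, surface:

def run_guarded(compute):
    # hypothetical helper: swallow only the errors we expect, surface the rest
    try:
        return compute()
    except RuntimeError as e:
        if "size" in str(e) or "shape" in str(e):  # illustrative filter for expected mismatches
            return None
        raise  # device mismatch and other surprises stay visible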
