Multi-GPU support lost when overriding methods in a custom Trainer

Hello, I’m trying to implement the data2vec model with HuggingFace.

To this end, I’ve implemented a HuggingFace model and a Trainer as follows:

The custom trainer:

class Data2VecTrainer(Trainer):
    """Trainer that delegates the loss to the model and updates its EMA teacher.

    The wrapped model is expected to expose ``compute_loss(inputs)``,
    ``take_ema_step()`` and a ``teacher_network`` attribute.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _data2vec_core(model):
        """Return the underlying module when ``model`` is an ``nn.DataParallel`` wrapper."""
        return model.module if isinstance(model, nn.DataParallel) else model

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the loss produced by the model's own ``compute_loss``.

        ``return_outputs`` is accepted for signature compatibility only; the
        model's ``compute_loss`` yields just the loss, so no outputs are returned.

        NOTE: ``nn.DataParallel`` splits a batch across replicas only when the
        wrapper itself is called (i.e. through ``forward``). Calling a custom
        method on ``model.module`` runs on the single underlying module, so this
        code path uses one device. To regain multi-GPU execution the loss
        computation has to be reachable from the model's ``forward``.
        """
        # The two branches of the original differ only in the unwrapping, so
        # they are folded into one call on the unwrapped module.
        loss = self._data2vec_core(model).compute_loss(inputs)
        return loss

    def training_step(self, model, inputs):
        """Run the regular optimisation step, then update the EMA teacher."""
        # regular step for updating model
        loss = super().training_step(model, inputs)

        # after model update, take a EMA step for teacher network
        # NOTE(review): the call below is reconstructed; it is not visible in
        # the listing. Confirm it matches the model's `take_ema_step`.
        student = self._data2vec_core(model)
        if student.teacher_network is not None:
            student.take_ema_step()
        return loss

The model:

class Data2VecTextEncoder(RobertaModel): 
    # just change the parent class to BertModel if testing with BERT 

    def __init__(self, config):
        """Build the encoder and the regression head used on student features.

        The head is ``cfg.head_layers`` linear layers: the first one widens the
        embedding dimension by 2x (when there is more than one layer) and the
        last one projects back to the embedding dimension.
        """
        # NOTE(review): reconstructed. `self.embeddings` is read below, so the
        # parent constructor must have run before this point.
        super().__init__(config)

        # NOTE(review): the logging call around this message is reconstructed;
        # confirm which logger the file uses.
        import logging

        logging.getLogger(__name__).info("Entered custom initialization.")

        # NOTE(review): this stores the config *class*, not an instance, so only
        # class-level attribute values are visible through self.cfg; confirm.
        self.cfg = Data2VecTextConfig

        self.loss_scale = self.cfg.loss_scale

        assert self.cfg.head_layers >= 1

        self.average_top_k_layers = self.cfg.average_top_k_layers

        embed_dim = self.embeddings.word_embeddings.embedding_dim
        curr_dim = embed_dim
        projs = []
        for i in range(self.cfg.head_layers - 1):
            # Only the first hidden layer widens; later ones keep the width.
            next_dim = embed_dim * 2 if i == 0 else curr_dim
            projs.append(nn.Linear(curr_dim, next_dim))
            curr_dim = next_dim

        # Final projection back to the embedding dimension.
        projs.append(nn.Linear(curr_dim, embed_dim))
        self.regression_head = nn.Sequential(*projs)

        # Counter read by take_ema_step when choosing the decay.
        self.num_updates = 0

    def make_ema_teacher(self):
        # Builds the EMA teacher and stores it in self.teacher_network.
        # NOTE(review): this listing is incomplete. The arguments of both calls
        # and the bodies of both loops are not visible, so the method does not
        # parse as shown. The missing parts must be restored from the author's
        # original code.
        ema_config = EMAModuleConfig(
        # Keys collected here are presumably handed to EMAModule as skip_keys
        # (TODO confirm).
        skip_keys = set()
        if self.cfg.ema_transformer_layers_only:
            # import pdb; pdb.set_trace()
            # Iterates the parameters of the embeddings and pooler sub-modules;
            # the loop bodies (presumably skip_keys.add(...)) are missing.
            for k, _ in self.embeddings.named_parameters():
            for k, _ in self.pooler.named_parameters():

        self.teacher_network = EMAModule(

    def take_ema_step(self): 
        # Updates the teacher's decay and advances the EMA teacher.
        # NOTE(review): this listing is incomplete. The `else:` before the
        # second `decay =` assignment, the arguments of get_annealed_rate, any
        # call that applies `decay`, and the body of the final `if` are not
        # visible.

        # The decay is only re-evaluated when start and end decay differ.
        if self.cfg.ema_decay != self.cfg.ema_end_decay:
            if self.num_updates >= self.cfg.ema_anneal_end_step:
                # At or past the anneal end step: use the final decay.
                decay = self.cfg.ema_end_decay
                # Presumably the `else` branch: annealed decay before the end
                # step (TODO confirm).
                decay = get_annealed_rate(
        # Guard on the teacher's current decay being below 1; what runs under
        # it is not visible (presumably the teacher's step).
        if self.teacher_network.get_decay() < 1:

    def compute_loss(self, inputs):
        """Compute the data2vec regression loss for one batch.

        The student encodes ``inputs['input_ids']``; the EMA teacher encodes
        ``inputs['original_input_ids']`` without gradients. The target is the
        average of the teacher's top ``cfg.average_top_k_layers`` hidden states
        (optionally normalised), and the loss is taken only at positions whose
        label is not -100.

        Args:
            inputs: mapping with ``input_ids``, ``original_input_ids``,
                ``attention_mask`` and ``labels``.

        Returns:
            Scalar loss tensor: the summed per-position loss divided by the
            square root of the feature size.
        """
        attention_mask = inputs['attention_mask']

        # send to forward pass
        student_output = self(
            inputs['input_ids'], attention_mask, output_hidden_states=True
        )

        with torch.no_grad():
            teacher_output = self.teacher_network.model(
                inputs['original_input_ids'],
                attention_mask,
                output_hidden_states=True,
            )

        # True where the label is -100, i.e. positions excluded from the loss.
        un_masked_indices = inputs.get("labels").eq(-100)

        student_features = student_output.get("last_hidden_state")
        # Index 0 of hidden_states is skipped; the rest are the per-layer states.
        teacher_inner_states = teacher_output.get("hidden_states")[1:]

        y = teacher_inner_states[-self.cfg.average_top_k_layers :]

        # The hidden states are batch-first (B, T, C): the mask built from
        # `labels` indexes the first two dims directly. Channel-wise norms need
        # (B, C, T), so swap the last two dims. The earlier `permute(1, 2, 0)`
        # assumed time-major input and produced the wrong layout here.
        permuted = False
        if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
            y = [tl.transpose(1, 2) for tl in y]  # BTC -> BCT
            permuted = True

        if self.cfg.batch_norm_target_layer:
            y = [
                F.batch_norm(
                    tl.float(), running_mean=None, running_var=None, training=True
                )
                for tl in y
            ]

        if self.cfg.instance_norm_target_layer:
            y = [F.instance_norm(tl.float()) for tl in y]

        if permuted:
            y = [tl.transpose(1, 2) for tl in y]  # BCT -> BTC

        if self.cfg.layer_norm_target_layer:
            y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]

        # Average the selected teacher layers into a single target.
        y = sum(y) / len(y)

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y.float(), y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        # Keep only the positions that contribute to the loss.
        x = student_features[~un_masked_indices]
        y = y[~un_masked_indices]

        x = self.regression_head(x)

        sz = x.size(-1)
        if self.cfg.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
            ).sum(dim=-1)

        loss = loss.sum() / math.sqrt(sz)

        return loss

The Exponential Moving Average model

class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    # The annotation is a string so that the class can be defined before (or
    # without) EMAModuleConfig being importable.
    def __init__(self, model, config: "EMAModuleConfig", device="cuda" if torch.cuda.is_available() else "cpu", skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param config EMAConfig object with configuration like
        ema_decay, ema_update_freq, ema_fp32
        @param device If provided, copy EMA to this device (e.g. gpu).
        Otherwise EMA is in the same device as the model.
        @param skip_keys state-dict keys that are copied from the new model
        instead of being averaged.
        """

        self.decay = config.ema_decay
        self.model = deepcopy(model)
        self.config = config
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if device is not None:
            # NOTE(review): the logging call is reconstructed around the
            # original message; confirm which logger the file uses.
            import logging

            logging.getLogger(__name__).info(
                f"Copying EMA model to device {device}"
            )
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If state dict is passed, the EMA params is copied from
        the provided state dict. Otherwise, it is copied from the
        current EMA model parameters.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                # Existing entry: overwrite in place so the tensor identity
                # is kept.
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param"
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            # Skipped keys are copied verbatim. Non-floating-point entries
            # (e.g. integer counters/buffers) are copied as well, because an
            # in-place multiply by a float decay is not defined for them.
            if key in self.skip_keys or not torch.is_floating_point(ema_param):
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model):
        """Update the EMA weights from ``new_model``."""
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model

I’ve adapted the training script here: transformers/examples/pytorch/language-modeling at master · huggingface/transformers · GitHub and found that, when I use the unmodified script, I am able to use multiple GPUs without explicitly specifying multi-GPU usage.

I am wondering whether this is happening because of the way I overrode the compute_loss function in the Trainer. If I comment it out, I can see with nvidia-smi that the model is loaded onto multiple GPUs, but because of the way my forward function is defined, it will crash.

I would appreciate any guidance on what I’ve done wrong in overriding these methods. Thank you!