How can I replace modules in a pretrained model?

I am trying to implement mixout, as defined here, for a study I’m working on. However, I’m running into an issue when I use the code from example.py with a pretrained model from Hugging Face.

Here’s an MWE, based on example.py from the linked repo but using a pretrained model instead of a bespoke one:

import torch
from torch import nn
from transformers import AutoModelForMaskedLM

from module import MixLinear

model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

print(model)

for name, module in model.named_modules():
    if isinstance(module, nn.Dropout):
        setattr(model, name, nn.Dropout(0))
        assert getattr(model, name).p == 0, f'Dropout was not disabled for {name}!'
    elif isinstance(module, nn.Linear):
        target_state_dict   = module.state_dict()
        bias                = True if module.bias is not None else False
        new_module          = MixLinear(
                                module.in_features, 
                                module.out_features, 
                                bias,
                                target_state_dict['weight'],
                                0.9
                            )
        new_module.load_state_dict(target_state_dict)
        setattr(model, name, new_module)
        
        assert isinstance(getattr(model, name), MixLinear), f'{name} was not correctly changed to use mixout!'

print(model)

The first issue I run into is during the loop over model.named_modules(), where I get this error:

Traceback (most recent call last):                                                                                            
  File "<stdin>", line 1, in <module>                                                                                         
  File "C:\Program Files\Python39\lib\site-packages\torch\nn\modules\module.py", line 1706, in named_modules                  
    for name, module in self._modules.items():                                                                                
RuntimeError: OrderedDict mutated during iteration

This error doesn’t appear if I run the loop a second time, or if I change for name, module in model.named_modules(): to for name, module in tuple(model.named_modules()):, so I made that change to proceed.
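The traceback points at the cause: named_modules() is a generator that walks each module’s _modules OrderedDict lazily, and setattr with a dotted name inserts a new key into model._modules while that walk is still in progress. Wrapping the call in tuple() collects every (name, module) pair before any assignment happens. Below is a plain-Python sketch of the same effect; walk() is a hypothetical stand-in for named_modules(), not PyTorch’s actual implementation.

```python
from collections import OrderedDict


def walk(children):
    """Hypothetical stand-in for named_modules(): a generator that lazily
    yields (name, child) pairs from an OrderedDict, much as the traceback
    shows PyTorch iterating self._modules."""
    for name, child in children.items():
        yield name, child


def add_during_lazy_iteration(children):
    """Insert new keys while the generator is still suspended mid-walk."""
    try:
        for name, child in walk(children):
            children[name + ".copy"] = child  # new key -> dict mutated
    except RuntimeError as err:
        return str(err)
    return "no error"


def add_after_snapshot(children):
    """Materialise every pair first, then mutate the dict freely."""
    for name, child in tuple(walk(children)):
        children[name + ".copy"] = child
    return sorted(children)


print(add_during_lazy_iteration(OrderedDict(a=1, b=2)))  # a RuntimeError message
print(add_after_snapshot(OrderedDict(a=1, b=2)))  # ['a', 'a.copy', 'b', 'b.copy']
```

The snapshot only fixes the iteration error; it does not change where setattr puts the new objects.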

However, although the loop now runs without throwing an error, the results aren’t what I expected. When I print the model before the loop, this is the output (as expected; truncated to fit the character limit for posts):

BertForMaskedLM(                                                                                                              
  (bert): BertModel(                                                                                                          
    (embeddings): BertEmbeddings(                                                                                             
      (word_embeddings): Embedding(30522, 768, padding_idx=0)                                                                 
      (position_embeddings): Embedding(512, 768)                                                                              
      (token_type_embeddings): Embedding(2, 768)                                                                              
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                      
      (dropout): Dropout(p=0.1, inplace=False)                                                                                
    )                                                                                                                         
    (encoder): BertEncoder(                                                                                                   
      (layer): ModuleList(                                                                                                    
        (0): BertLayer(                                                                                                       
          (attention): BertAttention(                                                                                         
            (self): BertSelfAttention(                                                                                        
              (query): Linear(in_features=768, out_features=768, bias=True)                                                   
              (key): Linear(in_features=768, out_features=768, bias=True)                                                     
              (value): Linear(in_features=768, out_features=768, bias=True)                                                   
              (dropout): Dropout(p=0.1, inplace=False)                                                                        
            )                                                                                                                 
            (output): BertSelfOutput(                                                                                         
              (dense): Linear(in_features=768, out_features=768, bias=True)                                                   
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                              
              (dropout): Dropout(p=0.1, inplace=False)                                                                        
            )                                                                                                                 
          )                                                                                                          
  
  [snip]         
                                                                           
  (cls): BertOnlyMLMHead(                                                                                                     
    (predictions): BertLMPredictionHead(                                                                                      
      (transform): BertPredictionHeadTransform(                                                                               
        (dense): Linear(in_features=768, out_features=768, bias=True)                                                         
        (transform_act_fn): GELUActivation()                                                                                  
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                    
      )                                                                                                                       
      (decoder): Linear(in_features=768, out_features=30522, bias=True)                                                       
    )                                                                                                                         
  )                                                                                                                           
)

When I print it after the loop, none of the original modules appear to have been replaced with MixLinear layers, and the dropout has not been set to 0. Instead, new entries containing the layers I want have been added at the top level of the model, named with the full dotted paths (this is also truncated to fit the character limit, but the rest is the same in all the relevant ways).

BertForMaskedLM(                                                                                                              
  (bert): BertModel(                                                                                                          
    (embeddings): BertEmbeddings(                                                                                             
      (word_embeddings): Embedding(30522, 768, padding_idx=0)                                                                 
      (position_embeddings): Embedding(512, 768)                                                                              
      (token_type_embeddings): Embedding(2, 768)                                                                              
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                      
      (dropout): Dropout(p=0.1, inplace=False)                                                                                
    )                                                                                                                         
    (encoder): BertEncoder(                                                                                                   
      (layer): ModuleList(                                                                                                    
        (0): BertLayer(                                                                                                       
          (attention): BertAttention(                                                                                         
            (self): BertSelfAttention(                                                                                        
              (query): Linear(in_features=768, out_features=768, bias=True)                                                   
              (key): Linear(in_features=768, out_features=768, bias=True)                                                     
              (value): Linear(in_features=768, out_features=768, bias=True)                                                   
              (dropout): Dropout(p=0.1, inplace=False)                                                                        
            )                                                                                                                 
            (output): BertSelfOutput(                                                                                         
              (dense): Linear(in_features=768, out_features=768, bias=True)                                                   
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                              
              (dropout): Dropout(p=0.1, inplace=False)                                                                        
            )                                                                                                                 
          )                                                                                                                   
          (intermediate): BertIntermediate(                                                                                   
            (dense): Linear(in_features=768, out_features=3072, bias=True)                                                    
            (intermediate_act_fn): GELUActivation()                                                                           
          )                                                                                                                   
          (output): BertOutput(                                                                                               
            (dense): Linear(in_features=3072, out_features=768, bias=True)                                                    
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                
            (dropout): Dropout(p=0.1, inplace=False)                                                                          
          )                                                                                                                   
        )                                                                                                            
  
  [snip]  
                                                                              
  (cls): BertOnlyMLMHead(                                                                                                     
    (predictions): BertLMPredictionHead(                                                                                      
      (transform): BertPredictionHeadTransform(                                                                               
        (dense): Linear(in_features=768, out_features=768, bias=True)                                                         
        (transform_act_fn): GELUActivation()                                                                                  
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                    
      )                                                                                                                       
      (decoder): Linear(in_features=768, out_features=30522, bias=True)                                                       
    )                                                                                                                         
  )                                                                                                                           
  (bert.embeddings.dropout): Dropout(p=0, inplace=False)                                                                      
  (bert.encoder.layer.0.attention.self.query): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)            
  (bert.encoder.layer.0.attention.self.key): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)              
  (bert.encoder.layer.0.attention.self.value): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)            
  (bert.encoder.layer.0.attention.self.dropout): Dropout(p=0, inplace=False)                                                  
  
  [snip]
  
  (bert.encoder.layer.11.attention.output.dropout): Dropout(p=0, inplace=False)                                               
  (bert.encoder.layer.11.intermediate.dense): MixLinear(mixout=0.9, in_features=768, out_features=3072, bias=True)            
  (bert.encoder.layer.11.output.dense): MixLinear(mixout=0.9, in_features=3072, out_features=768, bias=True)                  
  (bert.encoder.layer.11.output.dropout): Dropout(p=0, inplace=False)                                                         
  (cls.predictions.transform.dense): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                      
  (cls.predictions.decoder): MixLinear(mixout=0.9, in_features=768, out_features=30522, bias=True)                            
)

So setattr doesn’t seem to be the way to do what I want here, which is to replace the original layers and turn off the dropout in place. I confirmed this by calling getattr(model, 'bert.encoder.layer.0.attention.self.query') before the loop, and it does report that no attribute by that name exists.
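The reason is that setattr and getattr never interpret dots: the whole string is treated as one attribute name. Because nn.Module.__setattr__ stores any Module value in _modules under whatever name it is given, the result is the extra top-level entries such as (bert.embeddings.dropout) in the printout above. The same behaviour shows up without PyTorch; this sketch uses types.SimpleNamespace as a hypothetical stand-in for nested modules.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a nested model: model.bert.dropout
model = SimpleNamespace(bert=SimpleNamespace(dropout="old dropout"))

# Before any assignment, getattr does not follow the dotted path.
try:
    getattr(model, "bert.dropout")
except AttributeError:
    print("no attribute literally named 'bert.dropout'")

# setattr treats "bert.dropout" as ONE attribute name on the root object.
setattr(model, "bert.dropout", "new dropout")
print(model.bert.dropout)   # old dropout (the nested value is untouched)
print(sorted(vars(model)))  # ['bert', 'bert.dropout']

# Walking the path one attribute at a time reaches the real parent.
setattr(getattr(model, "bert"), "dropout", "new dropout")
print(model.bert.dropout)   # new dropout
```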

How can I swap out the Linear layers and turn off the dropout in place? If it were just this one model, I could brute-force it by walking through model._modules by hand, but I’m trying to do this for several models, and manually inspecting each model’s structure every time isn’t really feasible. It also seems like there must be a better way to do it than that!
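One generic pattern (sketched here under assumptions, not an official PyTorch API) is to collect all targets first, then split each dotted name into its parent path and leaf name, walk to the parent with getattr, and call setattr on the parent with only the leaf name. get_parent, replace_modules and named_objects are hypothetical names; the demo uses SimpleNamespace objects and strings in place of modules so it runs without PyTorch.

```python
from functools import reduce
from types import SimpleNamespace


def get_parent(root, dotted_name):
    """Return (parent_object, leaf_name) for a path like 'a.b.c'."""
    parent_path, _, leaf = dotted_name.rpartition(".")
    if not parent_path:
        return root, leaf
    return reduce(getattr, parent_path.split("."), root), leaf


def replace_modules(root, named_children, should_replace, make_replacement):
    """Replace every matching descendant in place; return replaced names.

    named_children is consumed into a list BEFORE any setattr, so the
    containers being walked are never mutated mid-iteration. The root
    itself (empty name) is never replaced.
    """
    targets = [(name, child) for name, child in named_children
               if name and should_replace(child)]
    for name, child in targets:
        parent, leaf = get_parent(root, name)
        setattr(parent, leaf, make_replacement(child))
    return [name for name, _ in targets]


def named_objects(obj, prefix=""):
    """Hypothetical stand-in for named_modules() over SimpleNamespace trees."""
    yield prefix, obj
    for key, child in vars(obj).items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(child, SimpleNamespace):
            yield from named_objects(child, name)
        else:
            yield name, child


model = SimpleNamespace(
    encoder=SimpleNamespace(query="Linear", dropout="Dropout"),
    decoder="Linear",
)
replaced = replace_modules(model, named_objects(model),
                           should_replace=lambda m: m == "Linear",
                           make_replacement=lambda m: "MixLinear")
print(replaced)  # ['encoder.query', 'decoder']
```

With a real model, the call would presumably look like replace_modules(model, model.named_modules(), lambda m: isinstance(m, nn.Linear), make_mixlinear), where make_mixlinear is a hypothetical wrapper around the MixLinear construction from the MWE. Numeric ModuleList children such as layer.0 should resolve the same way, since nn.Module’s getattr and setattr both look submodules up by their registered name ('0'). For dropout, replacement may not be needed at all: setting module.p = 0 on each nn.Dropout changes the probability it passes to dropout in its forward pass.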

I may have found a workaround! I won’t close this just in case someone notices an issue I’ve missed or knows a more straightforward way to do this.

import torch
from torch import nn
from transformers import AutoModelForMaskedLM

from copy import deepcopy

from module import MixLinear

model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

print(model)

def replace_layer(module):
    if isinstance(module, nn.Dropout):
        return nn.Dropout(0)
    elif isinstance(module, nn.Linear):
        target_state_dict   = deepcopy(module.state_dict())
        bias                = True if module.bias is not None else False
        new_module          = MixLinear(
                                module.in_features,
                                module.out_features,
                                bias,
                                target_state_dict['weight'],
                                0.9
                            )
        new_module.load_state_dict(target_state_dict)
        return new_module
    else:
        return module

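def recursive_setattr(obj, attr, value):
    # Split 'a.b.c' into ['a', 'b.c'] and walk down one level per call
    # until only the leaf attribute name is left.
    attr = attr.split('.', 1)
    if len(attr) == 1:
        setattr(obj, attr[0], value)
    else:
        recursive_setattr(getattr(obj, attr[0]), attr[1], value)

# tuple() snapshots the (name, module) pairs before anything is replaced;
# the root module has the empty name '' and is skipped.
for name, module in tuple(model.named_modules()):
    if name:
        recursive_setattr(model, name, replace_layer(module))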

print(model)

Now the output after the loop looks like this:

BertForMaskedLM(                                                                                                              
  (bert): BertModel(                                                                                                          
    (embeddings): BertEmbeddings(                                                                                             
      (word_embeddings): Embedding(30522, 768, padding_idx=0)                                                                 
      (position_embeddings): Embedding(512, 768)                                                                              
      (token_type_embeddings): Embedding(2, 768)                                                                              
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                      
      (dropout): Dropout(p=0, inplace=False)                                                                                  
    )                                                                                                                         
    (encoder): BertEncoder(                                                                                                   
      (layer): ModuleList(                                                                                                    
        (0): BertLayer(                                                                                                       
          (attention): BertAttention(                                                                                         
            (self): BertSelfAttention(                                                                                        
              (query): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (key): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                      
              (value): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (dropout): Dropout(p=0, inplace=False)                                                                          
            )                                                                                                                 
            (output): BertSelfOutput(                                                                                         
              (dense): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                              
              (dropout): Dropout(p=0, inplace=False)                                                                          
            )                                                                                                                 
          )                                                                                                                   
          (intermediate): BertIntermediate(                                                                                   
            (dense): MixLinear(mixout=0.9, in_features=768, out_features=3072, bias=True)                                     
            (intermediate_act_fn): GELUActivation()                                                                           
          )                                                                                                                   
          (output): BertOutput(                                                                                               
            (dense): MixLinear(mixout=0.9, in_features=3072, out_features=768, bias=True)                                     
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                
            (dropout): Dropout(p=0, inplace=False)                                                                            
          )                                                                                                                   
        )                                                                                                                     
        (1): BertLayer(                                                                                                       
          (attention): BertAttention(                                                                                         
            (self): BertSelfAttention(                                                                                        
              (query): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (key): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                      
              (value): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (dropout): Dropout(p=0, inplace=False)                                                                          
            )                                                                                                                 
            (output): BertSelfOutput(                                                                                         
              (dense): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                    
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                              
              (dropout): Dropout(p=0, inplace=False)                                                                          
            )                                                                                                                 
          )                                                                                                                   
          (intermediate): BertIntermediate(                                                                                   
            (dense): MixLinear(mixout=0.9, in_features=768, out_features=3072, bias=True)                                     
            (intermediate_act_fn): GELUActivation()                                                                           
          )                                                                                                                   
          (output): BertOutput(                                                                                               
            (dense): MixLinear(mixout=0.9, in_features=3072, out_features=768, bias=True)                                     
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                
            (dropout): Dropout(p=0, inplace=False)                                                                            
          )                                                                                                                   
        )                                                                                                                     
   
   [snip]
                                                                                                                    
  (cls): BertOnlyMLMHead(                                                                                                     
    (predictions): BertLMPredictionHead(                                                                                      
      (transform): BertPredictionHeadTransform(                                                                               
        (dense): MixLinear(mixout=0.9, in_features=768, out_features=768, bias=True)                                          
        (transform_act_fn): GELUActivation()                                                                                  
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)                                                    
      )                                                                                                                       
      (decoder): MixLinear(mixout=0.9, in_features=768, out_features=30522, bias=True)                                        
    )                                                                                                                         
  )                                                                                                                           
)  

This looks right to me! (Hopefully it actually is right!)
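Beyond eyeballing the printout, one way to check the result is to tally module class names after the loop and list any dropout whose p is still non-zero. audit below is a hypothetical helper; it assumes only an iterable of (name, module) pairs such as model.named_modules(), and it compares exact class names so a MixLinear would not be counted as a Linear even if it happened to subclass it.

```python
from collections import Counter


def audit(named_modules):
    """Tally module class names and collect dropouts that are still active.

    named_modules: iterable of (dotted_name, module) pairs,
    e.g. model.named_modules().
    Returns (Counter of class names, list of names with Dropout p != 0).
    """
    counts = Counter()
    active_dropout = []
    for name, module in named_modules:
        kind = type(module).__name__  # exact class, not isinstance
        counts[kind] += 1
        if kind == "Dropout" and getattr(module, "p", 0) != 0:
            active_dropout.append(name)
    return counts, active_dropout
```

On the replaced model, one would expect counts['Linear'] == 0, counts['MixLinear'] equal to the number of Linear layers before the loop, and an empty active_dropout list. Note that named_modules() skips duplicate module objects by default, so the counts are per unique instance.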