Loading a custom-class model instance saved with the Accelerate library fails

Hi,

I trained a model defined with a custom class on an 8-GPU setup using the Accelerate library. The model trains on the multi-GPU setup and is saved successfully.

Saving model:

# wait for all processes to finish training before saving
accelerator.wait_for_everyone()
# strip the distributed wrapper so the plain model's state dict is saved
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.eval()
accelerator.save(unwrapped_model.state_dict(), save_model)

Now I want to fine-tune the previously saved model on a different dataset, and this is how I load it:

model.load_state_dict(torch.load(model_path_level_1))
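
A minimal sketch of the same load with an explicit map_location, assuming the checkpoint was written from GPU memory; forcing the tensors onto CPU at this point is an added suggestion, not part of my original script:

import torch

# deserialize onto CPU so each process does not recreate the checkpoint
# tensors on cuda:0; accelerator.prepare will place the model afterwards
state_dict = torch.load(model_path_level_1, map_location='cpu')
model.load_state_dict(state_dict)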

And then pass it to accelerator.prepare:

accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
model, optimizer, train_dataloader, validation_dataloader = accelerator.prepare(model, optimizer, train_dataloader, validation_dataloader)

This script launches 8 processes on GPU:0, and the code fails with an out-of-memory error after some time. After the 8 processes appear on GPU:0, one process each is launched on the other 7 GPUs before the script crashes.
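
A minimal diagnostic sketch I can run right after loading, using the accelerator object from the snippet above, to see which device each process reports and how much memory it has already allocated on GPU:0 (the print format is arbitrary):

import torch

# which device does this process think it owns?
print(f"rank {accelerator.process_index} -> {accelerator.device}")

# memory this process has already allocated on cuda:0; a non-zero value in
# every rank before accelerator.prepare would mean all ranks hold tensors on GPU:0
print(f"{torch.cuda.memory_allocated(0) / 1024 ** 2:.1f} MiB allocated on cuda:0")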

Is this a problem with loading custom class models?

Thank you.

Is this all in the same script or in separate scripts? You may have the old model taking up space in memory in the first case.
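
If it were all one script, here is a minimal sketch of what I mean, explicitly releasing the old model before building the new one (a sketch only, assuming the old model is still referenced by the variable model):

import gc

import torch

# drop the last reference to the previously trained model
del model
gc.collect()
# return the cached CUDA blocks held by this process to the driver
torch.cuda.empty_cache()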

The first code snippet is in a different script, which completes its run.
The second and third code snippets are part of a script that runs after I have trained and saved a model.
Since both scripts are run one after the other, the first model is not in memory at all during the second script's run.

This is the custom class:

import os

import torch.nn as nn
from transformers import (BertConfig, BertModel, XLMRobertaConfig,
                          XLMRobertaModel)

# DIR and default_bert_params are defined elsewhere in my module


class BERTMLM(nn.Module):
    def __init__(self, language_model, topics, input_channels=64, params=default_bert_params,
                 save_pretrained_model=False, threshold=None,
                 lr_method: str = 'linear'):
        super(BERTMLM, self).__init__()
        if os.path.exists(DIR + '/../' + language_model):
            # use manually downloaded (with model.save_pretrained())
            lang_model = DIR + '/../' + language_model
        else:
            lang_model = language_model
        if language_model in ('bert-base-multilingual-cased', 'bert-base-cased',
                              'bert-base-german-cased'):
            self.bert_model = BertModel.from_pretrained(lang_model)
            # the config is needed below for hidden_size and vocab_size
            self.bert_config = BertConfig.from_pretrained(lang_model)
            self.output_vec_size = 768
        elif language_model == 'xlm-roberta-base':
            self.bert_model = XLMRobertaModel.from_pretrained(lang_model)
            self.bert_config = XLMRobertaConfig.from_pretrained(lang_model)
            self.output_vec_size = 768
        elif language_model == 'xlm-roberta-large':
            self.bert_model = XLMRobertaModel.from_pretrained(lang_model)
            self.bert_config = XLMRobertaConfig.from_pretrained(lang_model)
            self.output_vec_size = 1024
        else:
            raise NotImplementedError()
        if not os.path.exists(DIR + '/../' + language_model) and save_pretrained_model:
            self.bert_model.save_pretrained(DIR + '/../' + language_model)

        self.threshold = threshold

        self.lr = params['lr']
        self.max_epochs_without_loss_reduction = params['max_epochs_without_loss_reduction']
        self.epochs = params['epochs']
        self.params = params
        self.lr_method = lr_method
        self.batch_size = params['batch_size']
        self.vocab_size = self.bert_config.vocab_size

        # Final Layer - LM head to predict Masked token
        self.fc_layer = nn.Linear(self.bert_config.hidden_size, self.bert_config.hidden_size)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(self.bert_config.hidden_size, eps=1e-8)
        self.decoder = nn.Linear(self.bert_config.hidden_size, self.bert_config.vocab_size)

    def forward(self, b_input_ids, b_input_mask):
        out = self.bert_model(b_input_ids, attention_mask=b_input_mask, return_dict=False)[0]

        # Final Layer - LM head
        out = self.fc_layer(out)
        out = self.relu(out)
        out = self.layer_norm(out)
        out = self.decoder(out)

        return out
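
And, for completeness, a minimal sketch of how the second (fine-tuning) script puts the pieces together, launched with accelerate launch over 8 processes; map_location='cpu' is again only an assumption, and the constructor arguments, optimizer and dataloaders stand in for my actual configuration from the earlier snippets:

import torch
from accelerate import Accelerator, DistributedDataParallelKwargs

accelerator = Accelerator(
    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])

# rebuild the custom model on CPU with placeholder arguments
model = BERTMLM(language_model='xlm-roberta-base', topics=None)

# load the checkpoint written by the first script; map_location='cpu' keeps the
# deserialized tensors off cuda:0 until accelerator.prepare places the model
model.load_state_dict(torch.load(model_path_level_1, map_location='cpu'))

# let Accelerate move the model to this process's GPU and wrap it in DDP
model, optimizer, train_dataloader, validation_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, validation_dataloader)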