How do I load a custom multitask model from a checkpoint?

I have built a multitask model for translation and causal language modeling (CLM) using mBART, with this notebook as my reference. Here is the model class I wrote, followed by the rest of my training code.

class MultitaskModel(transformers.PreTrainedModel):
    """Single module that bundles one sub-model per task.

    The task models live in an ``nn.ModuleDict`` keyed by task name, and the
    encoder picked up by :meth:`create` is kept on ``self.encoder``.
    """

    def __init__(self, encoder, taskmodels_dict):
        """Store the shared encoder and the per-task models.

        Args:
            encoder: Encoder module kept on ``self.encoder``.
            taskmodels_dict: Mapping of task name to task model; it is
                wrapped in an ``nn.ModuleDict``.
        """
        # The wrapper is initialised with a blank PretrainedConfig rather
        # than a config of its own.
        # NOTE(review): a blank config is presumably why a saved checkpoint's
        # config.json carries no "model_type" -- confirm.
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """Build a ``MultitaskModel`` from one pretrained checkpoint.

        Every entry of ``model_type_dict`` is instantiated with
        ``from_pretrained(model_name, config=model_config_dict[task])``.
        The encoder returned by ``get_encoder()`` on the first model becomes
        the shared encoder; each later model gets that encoder assigned to
        its ``encoder`` attribute.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            model_type_dict: Mapping of task name to model class.
            model_config_dict: Mapping of task name to the config used for
                that task's model.

        Returns:
            A new instance of ``cls`` holding the shared encoder and the
            task models.

        Raises:
            KeyError: If a task in ``model_type_dict`` has no entry in
                ``model_config_dict``.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            task_model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = task_model.get_encoder()
            else:
                # NOTE(review): this only shares weights if the task model
                # really reads its encoder from a top-level ``encoder``
                # attribute -- confirm for the model classes in use.
                task_model.encoder = shared_encoder
            taskmodels_dict[task_name] = task_model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    def forward(self, task_name, **kwargs):
        """Run the sub-model registered for ``task_name``.

        Batches are expected to carry a ``task_name`` entry next to the
        model inputs; that entry selects the sub-model and all remaining
        keyword arguments are forwarded to it unchanged.

        Args:
            task_name: Key into ``self.taskmodels_dict``.
            **kwargs: Inputs passed straight to the selected task model.

        Returns:
            Whatever the selected task model returns.

        Raises:
            KeyError: If no model is registered under ``task_name``.
        """
        return self.taskmodels_dict[task_name](**kwargs)

# Pretrained checkpoint that every task model is initialised from.
model_name = 'facebook/mbart-large-50'

# One model class per task; each task gets its own freshly loaded config
# object (the config is loaded once per task, not shared between tasks).
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict=dict(
        TRANSLATION=MBartForConditionalGeneration,
        CLM=MBartForCausalLM,
    ),
    model_config_dict={
        task: AutoConfig.from_pretrained(model_name)
        for task in ('TRANSLATION', 'CLM')
    },
)

class StrIgnoreDevice(str):
    """A ``str`` subclass that also answers ``.to(device)``.

    NOTE(review): presumably exists so that code which calls ``.to(device)``
    on every value of a batch does not fail on the task-name entry -- confirm
    against the training loop in use.
    """

    def to(self, device):
        """Return this very object; ``device`` is ignored."""
        return self


class DataLoaderWithTaskname:
    """Wrap a data loader so each batch it yields is tagged with a task name.

    Attributes set in ``__init__``:
        task_name: Name written into every batch.
        data_loader: The wrapped loader.
        batch_size: Copied from the wrapped loader.
        dataset: Copied from the wrapped loader.
    """

    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        # Re-expose the wrapped loader's attributes on the wrapper itself.
        self.dataset = data_loader.dataset
        self.batch_size = data_loader.batch_size

    def __len__(self):
        """Return the length reported by the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield the wrapped loader's batches with ``"task_name"`` added.

        The batch object is modified in place: a new ``StrIgnoreDevice``
        built from ``self.task_name`` is stored under the ``"task_name"`` key.
        """
        for tagged_batch in self.data_loader:
            tagged_batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield tagged_batch


class MultitaskDataloader:
    """Present several per-task data loaders as one stream of batches.

    On each pass a schedule is built that contains every task index once per
    batch of that task (as reported by ``len(loader)`` at construction time),
    the schedule is shuffled with ``np.random.shuffle``, and batches are then
    drawn from the task loaders in that order.
    """

    def __init__(self, dataloader_dict):
        """Record the loaders and pre-compute their lengths.

        Args:
            dataloader_dict: Mapping of task name to a data loader that
                supports ``len()``, iteration and has a ``dataset`` attribute.
        """
        self.dataloader_dict = dataloader_dict
        self.task_name_list = list(dataloader_dict)
        self.num_batches_dict = {
            name: len(loader) for name, loader in dataloader_dict.items()
        }

        # ``dataset`` is only a placeholder list whose length equals the
        # combined size of all wrapped datasets; it holds no real examples.
        total_examples = 0
        for loader in dataloader_dict.values():
            total_examples += len(loader.dataset)
        self.dataset = [None] * total_examples

    def __len__(self):
        """Return the total number of batches across all tasks."""
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """Yield batches from the task loaders in a shuffled task order."""
        # One schedule slot per batch, holding the index of the owning task.
        schedule = np.array([
            task_idx
            for task_idx, name in enumerate(self.task_name_list)
            for _ in range(self.num_batches_dict[name])
        ])
        np.random.shuffle(schedule)

        # Start a fresh iterator over every loader before yielding anything.
        live_iters = {
            name: iter(loader) for name, loader in self.dataloader_dict.items()
        }
        for task_idx in schedule:
            yield next(live_iters[self.task_name_list[task_idx]])

class MultitaskTrainer(transformers.Trainer):
    """Trainer whose training data loader mixes batches from several tasks.

    ``self.train_dataset`` is read as a mapping of task name to dataset
    (``get_train_dataloader`` calls ``.items()`` on it).
    """

    def get_single_train_dataloader(self, task_name, train_dataset):
        """Build the task-tagged training loader for one task.

        Args:
            task_name: Name attached to every batch of this loader.
            train_dataset: Dataset for this task.

        Returns:
            A ``DataLoaderWithTaskname`` wrapping a ``DataLoader`` over
            ``train_dataset``.

        Raises:
            ValueError: If the trainer's own ``self.train_dataset`` is None.
        """
        # The guard looks at the trainer-level dataset, not the argument.
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # local_rank == -1 selects random sampling; any other value selects
        # the distributed sampler.
        if self.args.local_rank == -1:
            sampler = RandomSampler(train_dataset)
        else:
            sampler = DistributedSampler(train_dataset)

        # Batches are always collated with default_data_collator here.
        inner_loader = DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=default_data_collator,
        )
        return DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=inner_loader,
        )

    def get_train_dataloader(self):
        """Return a ``MultitaskDataloader`` over one loader per task."""
        per_task_loaders = {}
        for task_name, task_dataset in self.train_dataset.items():
            per_task_loaders[task_name] = self.get_single_train_dataloader(
                task_name, task_dataset
            )
        return MultitaskDataloader(per_task_loaders)

# Build the trainer for the multitask model. MultitaskTrainer reads
# `train_dataset` as a mapping of task name -> dataset.
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="/content/multi_model",
        overwrite_output_dir=True,
        learning_rate=1e-3,
        do_train=True,
        # The comma after this argument was missing, which made the whole
        # call a SyntaxError.
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_steps=5000,
    ),
    data_collator=default_data_collator,
    train_dataset=train_dataset,
)

Now, since the model is quite large and I am training on Colab, I need to save checkpoints and resume training the next day. But when I use the AutoModel class to load the model like this:

# Attempt to reload the saved checkpoint through the generic AutoModel entry
# point; this is the call that produces the "model_type" error described below.
multitask_model = AutoModel.from_pretrained('.../checkpoint-9000')

I get an error saying that “model_type” is not defined in the config.json file. If I set model_type to “mbart”, I get a runtime error instead, which is to be expected. How do I get around this issue?

How do I load the multitask_model from my checkpoints?

1 Like