I have built a multitask model for translation and causal language modeling (CLM) using mBART, using this notebook as a reference. Here is my code.
import numpy as np
import torch
import transformers
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler
from transformers import (
    AutoConfig,
    MBartForCausalLM,
    MBartForConditionalGeneration,
    default_data_collator,
)

class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        # Build one pretrained model per task; the first model's encoder is
        # kept and attached to the remaining task models so they share it.
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = model.get_encoder()
            else:
                # NOTE: for mBART the encoder lives at model.model.encoder, so
                # this top-level setattr may register an extra submodule rather
                # than rewiring the forward pass to the shared encoder.
                setattr(model, 'encoder', shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)
model_name = 'facebook/mbart-large-50'
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        'TRANSLATION': MBartForConditionalGeneration,
        'CLM': MBartForCausalLM,
    },
    model_config_dict={
        'TRANSLATION': AutoConfig.from_pretrained(model_name),
        'CLM': AutoConfig.from_pretrained(model_name),
    },
)
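To confirm the sharing took effect, a quick identity check can be run right after create() (a minimal sanity-check sketch of mine, not from the reference notebook):

# The wrapper's encoder and the translation head's encoder should be the
# same Python object, since create() hands that encoder to the wrapper.
shared = multitask_model.taskmodels_dict['TRANSLATION'].get_encoder()
assert multitask_model.encoder is shared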
class StrIgnoreDevice(str):
    """A str subclass whose .to(device) is a no-op, so the task name
    survives the Trainer's batch-to-device transfer."""
    def to(self, device):
        return self

class DataLoaderWithTaskname:
    """Wraps a DataLoader and tags every batch with its task name."""
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch
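As a small illustration: StrIgnoreDevice exists because the Trainer calls .to(device) on every value in a batch, which would fail on a plain str, while this subclass simply returns itself:

name = StrIgnoreDevice("TRANSLATION")
assert name.to("cuda") == "TRANSLATION"  # .to() is a no-op, the tag survives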
class MultitaskDataloader:
    """Interleaves batches from several DataLoaders in random task order."""
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        # Dummy dataset so the Trainer can read a total example count.
        self.dataset = [None] * sum(
            len(dataloader.dataset)
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        # Each task contributes as many draws as it has batches, so tasks
        # are sampled in proportion to their dataset sizes.
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])
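A usage sketch for the two classes above (translation_loader and clm_loader are hypothetical PyTorch DataLoaders, not defined in my code):

combined = MultitaskDataloader({
    'TRANSLATION': DataLoaderWithTaskname('TRANSLATION', translation_loader),
    'CLM': DataLoaderWithTaskname('CLM', clm_loader),
})
for batch in combined:
    task = batch['task_name']  # tells the training step which task head to run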
class MultitaskTrainer(transformers.Trainer):
    def get_single_train_dataloader(self, task_name, train_dataset):
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (
            RandomSampler(train_dataset)
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )
        return DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=default_data_collator,
            ),
        )

    def get_train_dataloader(self):
        # Override: build one tagged DataLoader per task and interleave them.
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })
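For context, the train_dataset passed to the trainer below is a dict keyed by task name (the two dataset variables are placeholders for my actual datasets):

train_dataset = {
    'TRANSLATION': translation_dataset,
    'CLM': clm_dataset,
}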
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="/content/multi_model",
        overwrite_output_dir=True,
        learning_rate=1e-3,
        do_train=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_steps=5000,
    ),
    data_collator=default_data_collator,
    train_dataset=train_dataset,
)
Now, since the model is quite large and I am training on Colab, I need to save checkpoints and resume training the next day. But when I try to load a checkpoint with the AutoModel class, like
multitask_model = AutoModel.from_pretrained('.../checkpoint-9000')
I get an error saying "model_type" is not defined in the config.json file. If I set model_type to "mbart" in config.json, I get a runtime error instead, which is fair enough, since the checkpoint is not a plain mBART model. How do I get around this issue?
How do I load the multitask_model from my checkpoints?
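One workaround I am considering is the unverified sketch below: it rebuilds the architecture with MultitaskModel.create() and loads the checkpoint weights by hand. The pytorch_model.bin filename is an assumption on my part; newer Trainer versions may save model.safetensors instead.

import torch

# Rebuild the same architecture, then load the saved weights directly
# instead of going through AutoModel (checkpoint path elided as above).
loaded_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        'TRANSLATION': MBartForConditionalGeneration,
        'CLM': MBartForCausalLM,
    },
    model_config_dict={
        'TRANSLATION': AutoConfig.from_pretrained(model_name),
        'CLM': AutoConfig.from_pretrained(model_name),
    },
)
state_dict = torch.load('.../checkpoint-9000/pytorch_model.bin', map_location='cpu')
loaded_model.load_state_dict(state_dict)

Is something like this the right way to do it, or is there a cleaner approach?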