I want to implement MAML on the GLUE datasets with transformers. In my case, the query and support sets will come from the same dataset. I’ve read some work on meta-learning from the HF team (Wolf et al., ’18). Although I’ve implemented my training loop (with higher; open to other methods as well), I am still looking for a correct reference implementation of MAML or Reptile to check against. Currently my code inherits from Trainer. If anyone could share a sample snippet that performs the MAML gradient updates, that’d be really helpful.
The MetaDataset wraps any GlueDataset so that meta_dataset[0] returns a list containing one example from each class. Each item thus becomes an N-way (num_of_classes), K-shot episode.
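For concreteness, a hypothetical sketch of what such a wrapper could look like (the class layout, the `.label` attribute, and K=1 sampling are my assumptions, not the actual code from the post):

```python
# Hypothetical MetaDataset sketch: groups a dataset's examples by
# label so that indexing returns one example per class, i.e. an
# N-way, 1-shot episode. Attribute names are assumptions.
import random
from collections import defaultdict


class MetaDataset:
    def __init__(self, glue_dataset):
        self.dataset = glue_dataset
        # Group example indices by their label.
        self.by_class = defaultdict(list)
        for idx in range(len(glue_dataset)):
            self.by_class[int(glue_dataset[idx].label)].append(idx)

    def __len__(self):
        # One episode per example of the rarest class.
        return min(len(v) for v in self.by_class.values())

    def __getitem__(self, index):
        # One randomly drawn example from each of the N classes.
        return [self.dataset[random.choice(idxs)]
                for idxs in self.by_class.values()]
```

A collate function would then stack these per-class examples into the meta-batch the training loop iterates over.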
I’ve written the following, which extends Trainer for MAML.
def train(self):
    self.create_optimizer_and_scheduler(
        int(
            len(self.train_dataloader)
            // self.args.gradient_accumulation_steps
            * self.args.num_train_epochs
        )
    )
    logger.info("***** Running training *****")
    self.global_step = 0
    self.epoch = 0
    eval_step = [2 ** i for i in range(1, 20)]
    inner_optimizer = torch.optim.SGD(
        self.model.parameters(), lr=self.args.step_size
    )
    self.model.train()
    tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

    self.optimizer.zero_grad()
    query_dataloader = iter(self.train_dataloader)

    for batch_idx, meta_batch in enumerate(tqdm_iterator):
        target_batch = next(query_dataloader)
        outer_loss = 0.0
        # Loop through all classes
        for inputs, target_inputs in zip(meta_batch, target_batch):
            for k in inputs:
                inputs[k] = inputs[k].to(self.args.device)
                target_inputs[k] = target_inputs[k].to(self.args.device)

            with higher.innerloop_ctx(
                self.model, inner_optimizer, copy_initial_weights=False
            ) as (fmodel, diffopt):
                inner_loss = fmodel(**inputs)[0]
                diffopt.step(inner_loss)
                outer_loss += fmodel(**target_inputs)[0]

        self.global_step += 1
        # Backprop through the inner-loop updates before the meta step
        outer_loss.backward()

        if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.args.max_grad_norm
            )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Run evaluation on task list
        if self.global_step in eval_step:
            output = self.prediction_loop(
                self.eval_dataloader, description="Evaluation"
            )
            self.log(output.metrics)
            output_dir = os.path.join(
                self.args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
            )
            self.save_model(output_dir)
I’m not completely sure how higher works. If someone can provide a minimal example in bare PyTorch, that’d be helpful.
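For what it’s worth, here is a minimal second-order MAML sketch in bare PyTorch on a toy linear-regression task distribution. The model, task sampler, and learning rates are all illustrative assumptions, not a reference implementation; the point is the shape of the inner/outer loop that higher automates:

```python
# Minimal second-order MAML in bare PyTorch on toy tasks y = a * x.
# The inner update is built with create_graph=True so the outer
# (meta) loss can backpropagate through it.
import torch

torch.manual_seed(0)

# Meta-parameters: a single linear layer kept as plain tensors so
# the inner update remains a differentiable expression.
w = torch.randn(1, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
meta_opt = torch.optim.Adam([w, b], lr=1e-2)
inner_lr = 0.1


def forward(x, w, b):
    return x @ w + b


for step in range(100):
    meta_opt.zero_grad()
    outer_loss = 0.0
    for _ in range(4):  # tasks per meta-batch
        # Sample a task: a random slope, with support and query sets.
        a = torch.rand(1) * 2
        x_s, x_q = torch.randn(10, 1), torch.randn(10, 1)
        y_s, y_q = a * x_s, a * x_q

        # Inner step on the support set; create_graph=True keeps the
        # update differentiable w.r.t. the meta-parameters.
        inner_loss = ((forward(x_s, w, b) - y_s) ** 2).mean()
        g_w, g_b = torch.autograd.grad(
            inner_loss, (w, b), create_graph=True
        )
        w_adapted = w - inner_lr * g_w
        b_adapted = b - inner_lr * g_b

        # Outer loss: query loss under the adapted parameters.
        outer_loss = outer_loss + (
            (forward(x_q, w_adapted, b_adapted) - y_q) ** 2
        ).mean()

    outer_loss.backward()  # gradients flow back into w and b
    meta_opt.step()
```

In the higher version, `diffopt.step(inner_loss)` plays the role of the `torch.autograd.grad` + manual parameter update, and `fmodel` holds the adapted parameters.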
Hey @prajjwal1, did you implement this?