[Help needed] Extending Trainer for Meta learning

I want to implement MAML on the GLUE dataset with transformers. In my case, the query and support sets will come from the same dataset. I’ve read some of the HF team’s work on meta-learning (Wolf et al., 18).
I’ve implemented my own training loop using higher (I’m open to other methods as well), but I’m still looking for a correct reference implementation of MAML or Reptile to check it against. My code currently inherits from Trainer. If anyone could share a sample snippet that performs MAML gradient updates, that would be really helpful.

MetaDataset wraps any GlueDataset so that indexing it, e.g. meta_dataset[0], returns a list with one entry per class. Each item is therefore an N-way K-shot example, where N is num_of_classes.
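To make the data layout concrete, here is a rough sketch of the idea in plain Python. It is not my actual MetaDataset: the constructor arguments (examples, labels, k_shot, seed) and the sampling strategy are assumptions for illustration, and the real class wraps a GlueDataset rather than lists of strings.

```python
import random
from collections import defaultdict


class MetaDataset:
    """Hypothetical sketch: groups labelled examples by class for N-way K-shot sampling."""

    def __init__(self, examples, labels, k_shot, seed=0):
        self.k_shot = k_shot
        self.rng = random.Random(seed)
        self.by_class = defaultdict(list)
        for example, label in zip(examples, labels):
            self.by_class[label].append(example)
        self.classes = sorted(self.by_class)

    def __len__(self):
        # Number of K-shot episodes the smallest class can supply without reuse.
        return min(len(v) for v in self.by_class.values()) // self.k_shot

    def __getitem__(self, index):
        # One entry per class, each holding K randomly drawn examples of that class.
        return [self.rng.sample(self.by_class[c], self.k_shot) for c in self.classes]


# Toy data: 3 classes with 4 examples each.
examples = [f"sentence {i}" for i in range(12)]
labels = [i % 3 for i in range(12)]
meta_dataset = MetaDataset(examples, labels, k_shot=2)
episode = meta_dataset[0]  # list of 3 per-class groups, 2 examples in each
```

In this sketch the index passed to meta_dataset[...] is ignored and each call draws a fresh random episode; a deterministic mapping from index to episode would work just as well.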

This is the train method I’ve written in my Trainer subclass for MAML:

def train(self):

        self.create_optimizer_and_scheduler(
            int(
                len(self.train_dataloader)
                // self.args.gradient_accumulation_steps
                * self.args.num_train_epochs
            )
        )

        logger.info("***** Running training *****")

        self.global_step = 0
        self.epoch = 0

        eval_step = [2 ** i for i in range(1, 20)]
        inner_optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.args.step_size
        )
        self.model.train()

        tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

        #  n_inner_iter = 5
        self.optimizer.zero_grad()
        query_dataloader = iter(self.train_dataloader)

        for batch_idx, meta_batch in enumerate(tqdm_iterator):
            target_batch = next(query_dataloader)
            outer_loss = 0.0
            # Loop through all classes
            for inputs, target_inputs in zip(meta_batch, target_batch):

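                # Move support and query tensors to the device separately,
                # so the query batch is not overwritten with support tensors.
                for k, v in inputs.items():
                    inputs[k] = v.to(self.args.device)
                for k, v in target_inputs.items():
                    target_inputs[k] = v.to(self.args.device)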

                with higher.innerloop_ctx(
                    self.model, inner_optimizer, copy_initial_weights=False
                ) as (fmodel, diffopt):

                    inner_loss = fmodel(**inputs)[0]
                    diffopt.step(inner_loss)
                    outer_loss += fmodel(**target_inputs)[0]

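            # Backpropagate through the inner-loop update into self.model's parameters.
            outer_loss.backward()
            self.global_step += 1

            if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                # Clip, then step, then clear gradients for the next meta-batch.
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.args.max_grad_norm
                )
                self.optimizer.step()
                self.optimizer.zero_grad()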

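            # Run evaluation on task list
            if self.global_step in eval_step:
                output = self.prediction_loop(self.eval_dataloader, description="Evaluation")
                self.log(output.metrics)
                # prediction_loop puts the model in eval mode; switch back before training continues.
                self.model.train()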

                output_dir = os.path.join(
                    self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
                )
                self.save_model(output_dir)

I’m not completely sure how higher works. If someone can provide a minimal example in bare PyTorch, that would be helpful.
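To make the question concrete, this is roughly what I understand a second-order MAML step to be in bare PyTorch, on a toy linear regression model instead of a transformer. Everything here (the linear model, sample_task, maml_step, the learning rates) is made up for illustration. The part I believe matters is torch.autograd.grad(..., create_graph=True), which keeps the inner update in the graph so that outer_loss.backward() reaches the meta-parameters. I’d appreciate confirmation that higher’s innerloop_ctx and diffopt.step compute something equivalent.

```python
import torch

torch.manual_seed(0)

# Meta-parameters of a tiny linear model y = x @ w + b (stand-in for a transformer).
w = torch.zeros(1, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
meta_params = [w, b]
meta_optimizer = torch.optim.SGD(meta_params, lr=0.01)
inner_lr = 0.1  # plays the role of args.step_size in my Trainer code


def forward(params, x):
    weight, bias = params
    return x @ weight + bias


def mse(params, x, y):
    return ((forward(params, x) - y) ** 2).mean()


def sample_task():
    """Hypothetical toy task: regress y = a * x for a random slope a in [1, 3)."""
    a = torch.rand(1) * 2 + 1
    x_s, x_q = torch.randn(10, 1), torch.randn(10, 1)
    return (x_s, a * x_s), (x_q, a * x_q)  # (support set, query set)


def maml_step(tasks):
    meta_optimizer.zero_grad()
    outer_loss = 0.0
    for (x_s, y_s), (x_q, y_q) in tasks:
        # Inner step: one SGD update on the support set, kept in the autograd graph.
        inner_loss = mse(meta_params, x_s, y_s)
        grads = torch.autograd.grad(inner_loss, meta_params, create_graph=True)
        fast_params = [p - inner_lr * g for p, g in zip(meta_params, grads)]
        # Outer loss: evaluate the adapted parameters on the query set.
        outer_loss = outer_loss + mse(fast_params, x_q, y_q)
    outer_loss = outer_loss / len(tasks)
    outer_loss.backward()  # differentiates through the inner update
    meta_optimizer.step()
    return outer_loss.item()


# 200 meta-updates, each over a meta-batch of 4 freshly sampled tasks.
losses = [maml_step([sample_task() for _ in range(4)]) for _ in range(200)]
```

Note the order inside maml_step: zero_grad, accumulate the outer loss over tasks, backward, then step. Replacing create_graph=True with False would, as far as I understand, give the first-order approximation instead.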