How can I visualize the attention maps of my SegFormer model?

Hello, I would like to visualize the attention map of every head. I found a notebook that shows how to do this with the DINO model, but I would like to know whether any examples exist for SegFormer. I have tried a lot of code, but none of it works for me. Also, the SegFormer config has no `patch_size` attribute, only `patch_sizes`, so I don't know what to use in its place.
Here's my `SegformerFinetuner` class:

    """Class of instance pytorch lightning
    to Train & fine tune the model
    """

    def __init__(
        self,
        id2label,
        pretrained_model_name,
        learning_rate,
        metrics_interval=100,
    ):
        """Store the label mappings and training settings, create metrics, load the model.

        Args:
            id2label: mapping from class id to class name. The number of its
                keys becomes ``num_classes``, and it is inverted to build
                ``label2id``.
            pretrained_model_name: checkpoint name or path handed to
                ``SegformerForSemanticSegmentation.from_pretrained``.
            learning_rate: learning rate read by ``configure_optimizers``.
            metrics_interval: stored on the instance as-is.
        """
        super().__init__()

        # Plain configuration attributes.
        self.pretrained_model_name = pretrained_model_name
        self.learning_rate = learning_rate
        self.metrics_interval = metrics_interval
        self.id2label = id2label
        self.num_classes = len(id2label.keys())
        self.label2id = dict((name, idx) for idx, name in id2label.items())

        # One independent mean-IoU accumulator per stage, loaded in the order
        # train -> valid -> test.
        self.train_mean_iou, self.valid_mean_iou, self.test_mean_iou = (
            evaluate.load("mean_iou") for _ in range(3)
        )

        # Must be called directly from __init__: Lightning reads the init
        # arguments from this frame.
        self.save_hyperparameters()

        hf_kwargs = {
            # The step methods index the output tuple positionally.
            "return_dict": False,
            "num_labels": self.num_classes,
            "id2label": self.id2label,
            "label2id": self.label2id,
            # The classification head is re-initialised for num_labels.
            "ignore_mismatched_sizes": True,
        }
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            self.pretrained_model_name, **hf_kwargs
        )

    def get_attention_map(self, images, masks=None):
        """Run the model once and return its self-attention weights.

        The wrapped model is loaded with ``return_dict=False``, so a plain call
        returns a tuple that has no ``.attentions`` attribute. This method
        therefore forces ``return_dict=True`` for this one call.

        ``interpolate_pos_encoding`` is deliberately not passed. It is a
        ViT/DINO argument that SegFormer's ``forward`` does not accept, and
        passing it raises ``TypeError``.

        Args:
            images: pixel values, forwarded as ``pixel_values``.
            masks: optional targets, forwarded as ``labels``. When given, the
                model also computes its own loss, which is ignored here.

        Returns:
            ``outputs.attentions`` from the model. Per the Hugging Face
            SegFormer docs this is a tuple with one tensor per transformer
            layer, presumably shaped (batch, num_heads, query_len, key_len).
            key_len is shortened by each stage's sequence-reduction ratio, so
            the maps are not square like DINO's. Confirm the shapes against
            your transformers version before reshaping them into heatmaps.
        """
        outputs = self.model(
            pixel_values=images,
            labels=masks,
            output_attentions=True,
            return_dict=True,
        )
        return outputs.attentions

    def forward(self, images, masks=None, output_attentions=True):
        """Run the wrapped SegFormer model on ``images``.

        Args:
            images: pixel values, passed to the model as ``pixel_values``.
            masks: accepted so callers can write ``self(images, masks)``, but
                not forwarded to the model. The step methods compute their
                loss from the returned logits themselves.
            output_attentions: whether the model should also return its
                attention weights. Defaults to ``True`` to keep the original
                behaviour. Pass ``False`` when the attentions are not needed,
                so one attention tensor per layer is not kept alive.

        Returns:
            The model output. The model is loaded with ``return_dict=False``,
            so this is a tuple whose first element the step methods use as
            the logits.
        """
        return self.model(pixel_values=images, output_attentions=output_attentions)

    def training_step(self, batch, batch_idx):
        """Compute the Dice loss and per-batch mean IoU / accuracy for one training batch.

        Args:
            batch: mapping with ``"pixel_values"`` (model input) and
                ``"labels"`` (target masks).
            batch_idx: index of the batch; unused.

        Returns:
            The Dice loss tensor that Lightning back-propagates.
        """
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images=images, masks=masks)
        predictions = outputs[0]

        # Resize the logits to the masks' spatial size so the loss and the
        # metric compare pixel-for-pixel.
        predictions = nn.functional.interpolate(
            predictions,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

        # Ignore the same label (255) that the mean-IoU metric below ignores.
        # Without ignore_index, a 255 pixel would be treated as a class index
        # by DiceLoss (presumably failing in its one-hot step; see the smp
        # docs).
        dloss = DiceLoss(mode="multiclass", ignore_index=255)
        loss = dloss(predictions, masks)

        predictions = predictions.argmax(dim=1)

        self.train_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )

        # Computed on every step, so the logged values describe this batch.
        metrics = self.train_mean_iou.compute(
            num_labels=self.num_classes, ignore_index=255, reduce_labels=False
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "mean_iou", metrics["mean_iou"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "mean_accuracy",
            metrics["mean_accuracy"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the Dice loss for one validation batch and accumulate mean-IoU inputs.

        Args:
            batch: mapping with ``"pixel_values"`` (model input) and
                ``"labels"`` (target masks).
            batch_idx: index of the batch; unused.

        Returns:
            The Dice loss tensor, also logged as ``valid_loss``.
        """
        images, masks = batch["pixel_values"], batch["labels"]

        outputs = self(images, masks)
        predictions = outputs[0]
        # Resize the logits to the masks' spatial size before loss and metric.
        predictions = nn.functional.interpolate(
            predictions,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        # Ignore label 255, consistent with the ignore_index used by the
        # mean-IoU metric in this module.
        dloss = DiceLoss(mode="multiclass", ignore_index=255)

        loss = dloss(predictions, masks)

        predictions = predictions.argmax(dim=1)

        # NOTE(review): valid_mean_iou is only accumulated here. No compute()
        # call is visible in this class, so confirm it is computed (and
        # thereby reset) at epoch end somewhere.
        self.valid_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )

        self.log("valid_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_nb):
        """Compute the Dice loss for one test batch and accumulate mean-IoU inputs.

        Uses the same forward path and Dice loss as the training and
        validation steps, rather than the model's built-in label loss. This
        makes ``test_loss`` measure the same quantity as ``valid_loss``.

        Args:
            batch: mapping with ``"pixel_values"`` (model input) and
                ``"labels"`` (target masks).
            batch_nb: index of the batch; unused.

        Returns:
            The Dice loss tensor, also logged as ``test_loss``.
        """
        images, masks = batch["pixel_values"], batch["labels"]

        outputs = self(images, masks)
        logits = outputs[0]

        # Resize the logits to the masks' spatial size before loss and metric.
        upsampled_logits = nn.functional.interpolate(
            logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )

        # Ignore label 255, matching the metric's ignore_index.
        dloss = DiceLoss(mode="multiclass", ignore_index=255)
        loss = dloss(upsampled_logits, masks)

        predicted = upsampled_logits.argmax(dim=1)

        # NOTE(review): test_mean_iou is only accumulated here; confirm a
        # compute() call happens elsewhere.
        self.test_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.log("test_loss", loss, on_step=True, prog_bar=True)

        return loss


    def configure_optimizers(self):
        """Build Adam over the trainable parameters plus a plateau scheduler on ``valid_loss``.

        Returns:
            A Lightning optimizer config dict. The scheduler sits under the
            ``"lr_scheduler"`` key together with the metric it monitors. The
            previous key was misspelled ``"lr_schedulter"``; Lightning only
            warns about unknown keys, so the scheduler was never stepped.
        """
        # Frozen parameters are left out of the optimizer.
        trainable = [p for p in self.parameters() if p.requires_grad]
        opt = torch.optim.Adam(
            trainable,
            lr=self.learning_rate,
            eps=1e-08,
            weight_decay=0.1,
            amsgrad=True,
        )
        # `verbose` is deprecated in recent PyTorch releases. To track the
        # learning rate, use Lightning's LearningRateMonitor or
        # scheduler.get_last_lr() instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=opt, mode="min", patience=3
        )

        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"},
        }