Hello, I would like to visualize the attention map at every head. I found a notebook that shows how to do it with the DINO model, but I would like to know whether any examples exist for SegFormer. I have tried a lot of snippets, but none of them works for me. One thing that confuses me is that the SegFormer config has no `patch_size`, only `patch_sizes` (one entry per encoder stage), so the DINO-style reshaping does not carry over directly. What should I do?
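For reference, here is the kind of per-head reshaping I am trying to reproduce, as a minimal sketch: the checkpoint name is just an example, the input is random, and I am assuming a square image so the token grids come out as perfect squares. Unlike DINO, SegFormer has one token grid per encoder stage (which is why the config exposes `patch_sizes`, `strides`, and `sr_ratios` as lists), and the key sequence is shorter than the query sequence whenever a stage's `sr_ratio` is greater than 1:

```python
import torch
from transformers import SegformerForSemanticSegmentation

# example checkpoint; any SegFormer checkpoint should behave the same way
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
model.eval()
image = torch.randn(1, 3, 512, 512)  # dummy square input

with torch.no_grad():
    outputs = model(pixel_values=image, output_attentions=True, return_dict=True)

# one tensor per encoder layer, sum(config.depths) entries in total
for layer, attn in enumerate(outputs.attentions):
    batch, num_heads, num_q, num_kv = attn.shape
    # assuming a square input, recover this stage's key grid
    side = int(num_kv ** 0.5)
    # average attention received by each key location, one map per head
    head_maps = attn.mean(dim=2).reshape(batch, num_heads, side, side)
    print(f"layer {layer}: {tuple(attn.shape)} -> {tuple(head_maps.shape)}")
```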
Here’s my `SegformerFinetuner` class:

```python
import evaluate
import pytorch_lightning as pl
import torch
from torch import nn
from transformers import SegformerForSemanticSegmentation

# assuming DiceLoss comes from segmentation_models_pytorch
from segmentation_models_pytorch.losses import DiceLoss


class SegformerFinetuner(pl.LightningModule):
"""Class of instance pytorch lightning
to Train & fine tune the model
"""
def __init__(
self,
id2label,
pretrained_model_name,
learning_rate,
metrics_interval=100,
):
super().__init__()
self.id2label = id2label
self.learning_rate = learning_rate
self.metrics_interval = metrics_interval
self.num_classes = len(id2label.keys())
self.label2id = {v: k for k, v in id2label.items()}
self.pretrained_model_name = pretrained_model_name
self.train_mean_iou = evaluate.load("mean_iou")
self.valid_mean_iou = evaluate.load("mean_iou")
self.test_mean_iou = evaluate.load("mean_iou")
self.save_hyperparameters()
self.model = SegformerForSemanticSegmentation.from_pretrained(
self.pretrained_model_name,
            return_dict=False,  # the step methods below index into tuple outputs
num_labels=self.num_classes,
id2label=self.id2label,
label2id=self.label2id,
ignore_mismatched_sizes=True,
)
    def get_attention_map(self, images, masks=None):
        """Return the attention maps of every encoder layer and head.

        Passing `output_attentions=True` and `return_dict=True` makes the
        output expose an `attentions` field. SegFormer has no positional
        encodings, so there is no `interpolate_pos_encoding` argument here.
        """
        outputs = self.model(
            pixel_values=images,
            labels=masks,
            output_attentions=True,
            return_dict=True,  # override the return_dict=False set at load time
        )
        return outputs.attentions
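    # A note on shapes (my understanding of the SegFormer encoder):
    # `attentions` is a tuple with sum(config.depths) tensors, one per layer.
    # A layer in stage i has shape (batch, num_attention_heads[i], num_queries,
    # num_keys), and num_keys < num_queries whenever sr_ratios[i] > 1, because
    # the efficient attention spatially reduces the key/value sequence.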
    def forward(self, images, masks=None):
        """Run the model on a batch of images.

        `masks` is accepted for API symmetry but not forwarded, so with
        `return_dict=False` the first element of the output tuple is the logits.
        """
        outputs = self.model(pixel_values=images, output_attentions=True)
        return outputs
def training_step(self, batch, batch_idx):
"""Training step"""
images, masks = batch["pixel_values"], batch["labels"]
outputs = self(images=images, masks=masks)
predictions = outputs[0]
predictions = nn.functional.interpolate(
predictions,
size=masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
dloss = DiceLoss(mode="multiclass")
loss = dloss(predictions, masks)
predictions = predictions.argmax(dim=1)
        self.train_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        # computing mean IoU is expensive, so only do it every
        # `metrics_interval` steps
        if batch_idx % self.metrics_interval == 0:
            metrics = self.train_mean_iou.compute(
                num_labels=self.num_classes, ignore_index=255, reduce_labels=False
            )
            self.log(
                "mean_iou",
                metrics["mean_iou"],
                on_step=True,
                on_epoch=True,
                prog_bar=True,
            )
            self.log(
                "mean_accuracy",
                metrics["mean_accuracy"],
                on_step=True,
                on_epoch=True,
                prog_bar=True,
            )
return loss
    def validation_step(self, batch, batch_idx):
        """Validation step: log the Dice loss and accumulate IoU statistics."""
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images, masks)
        predictions = outputs[0]
        predictions = nn.functional.interpolate(
            predictions,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        dloss = DiceLoss(mode="multiclass")
        loss = dloss(predictions, masks)
        predictions = predictions.argmax(dim=1)
        self.valid_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.log("valid_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Flush the IoU statistics accumulated over the validation epoch."""
        metrics = self.valid_mean_iou.compute(
            num_labels=self.num_classes, ignore_index=255, reduce_labels=False
        )
        self.log("valid_mean_iou", metrics["mean_iou"], prog_bar=True)
    def test_step(self, batch, batch_idx):
        """Test step: use the model's built-in cross-entropy loss."""
        images, masks = batch["pixel_values"], batch["labels"]
        # passing labels makes the model return (loss, logits) with return_dict=False
        outputs = self.model(pixel_values=images, labels=masks)
        loss, logits = outputs[0], outputs[1]
        upsampled_logits = nn.functional.interpolate(
            logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )
        predicted = upsampled_logits.argmax(dim=1)
        self.test_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.log("test_loss", loss, on_step=True, prog_bar=True)
        return loss
    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau driven by the validation loss."""
        opt = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.learning_rate,
            eps=1e-08,
            weight_decay=0.1,
            amsgrad=True,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=opt, mode="min", patience=3, verbose=True
        )
        # ReduceLROnPlateau needs a "monitor" entry so Lightning knows
        # which logged metric to watch
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"},
        }
```