Detection Transformer (DETR) for text detection in documents

Hi,
I'm currently running some experiments on text detection with a transformer-based model.
Does anyone have experience with this, or any recommendations?
My idea is to train DetrForObjectDetection on the COCOText-v2 dataset:
COCOText-v2

I have tested a few setups:

  • pretrained facebook/detr-resnet-50 with num_queries=2000 (which seemed like a reasonable value for an A4 document page)
  • from scratch with an efficientnet_b0 backbone from timm, with backbone lr 0.001 and overall lr 0.01

But in all cases the training and validation loss get stuck at ~1.7 after ~35 epochs (with 2 validation steps per epoch).
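
One sanity check I still plan to run: let the model overfit a tiny subset. If the loss plateaus at ~1.7 even on a handful of images, the setup (num_queries, label format, learning rates) is more suspect than the data. A rough sketch using PyTorch Lightning's overfit_batches flag (variable names taken from my trainer script below):

# sanity check: repeatedly train on the same 2 batches and watch the loss;
# it should drop towards 0 if the pipeline is wired up correctly
overfit_trainer = pl.Trainer(overfit_batches=2, max_epochs=100,
                             gpus=torch.cuda.device_count())
overfit_trainer.fit(text_detection_model, train, val)
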
Another problem I have faced is with the CocoEvaluator: at the validation step it fails with "'numpy.ndarray' object has no attribute 'append'", coming from this line in the evaluator:

self.eval_imgs[iou_type].append(eval_imgs)
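
If I read the evaluator code correctly, synchronize_between_processes() overwrites self.eval_imgs[iou_type] with a concatenated numpy array, so any later update() call tries .append on an ndarray. A workaround I want to try (just a sketch, assuming the evaluator behaves like the one from the original DETR repo) is to rebuild the evaluator at the start of every validation epoch:

    # inside the LightningModule; self.coco_gt would be the COCO_Text
    # ground-truth object stored in __init__ (that attribute name is mine)
    def on_validation_epoch_start(self):
        # a fresh evaluator makes update() append to plain lists again instead
        # of the numpy arrays left behind by synchronize_between_processes()
        self.coco_evaluator = CocoEvaluator(self.coco_gt, ['bbox'])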

One sample from my train dataloader looks like this:

# pixel_values 1 example
torch.Size([3, 640, 640]) 
# target for this example
{'boxes': tensor([[0.0810, 0.8323, 0.1621, 0.1356],
        [0.3031, 0.3070, 0.0367, 0.0088],
        [0.5304, 0.3418, 0.0349, 0.0102]]), 'class_labels': tensor([0, 0, 0]), 'image_id': tensor([367969]), 'area': tensor([5295.0200,  103.8200,  105.6000]), 'iscrowd': tensor([0, 0, 0]), 'orig_size': tensor([640, 556]), 'size': tensor([640, 556])}

So the data coming out of the DataLoader seems to be OK.
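
As an extra check I denormalize the (cx, cy, w, h) boxes back to absolute (x0, y0, x1, y1) pixel coordinates and draw them on the image. This is the small helper I use for that (assuming target['size'] is (height, width), as in the sample above):

def denormalize_boxes(boxes, size):
    # boxes: (N, 4) tensor in normalized (cx, cy, w, h) format, as produced
    # by DetrFeatureExtractor; size: (height, width) of the processed image
    h, w = size
    cx, cy, bw, bh = boxes.unbind(-1)
    return torch.stack([(cx - 0.5 * bw) * w, (cy - 0.5 * bh) * h,
                        (cx + 0.5 * bw) * w, (cy + 0.5 * bh) * h], dim=-1)

# e.g. denormalize_boxes(target['boxes'], target['size']) for the sample above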

Some more code:

COCO_stuff, adapted from:
COCOText
Pytorch COCO

Dataloader:

# create the feature extractor once at module level; instantiating a new
# DetrFeatureExtractor inside collate_fn would rebuild it for every batch
pad_feature_extractor = DetrFeatureExtractor()

def collate_fn(batch):
    """Pad all samples of a batch to a common size and build the pixel mask."""
    pixel_values = [item[0] for item in batch]
    encoding = pad_feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    return {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }


class CocoTextDataset(Dataset):
    """MSCOCO Text V2 Dataset
    """
    def __init__(self, path, ann_file_name, image_folder_name, feature_extractor, is_train=True, data_limit=None):
        self.path = path
        self.annotation_path = os.path.join(path, ann_file_name)
        self.image_folder_path = os.path.join(path, image_folder_name)
        self.feature_extractor = feature_extractor
        self.data_limit = data_limit
        self.dataset_length = 0

        self.coco_text = COCO_Text(annotation_file=self.annotation_path)

        if is_train:
            print('Load Training Data')
            self.set_part = self.coco_text.train
        else:
            print('Load Validation Data')
            self.set_part = self.coco_text.val

        # keep only images of the chosen split that actually have annotations
        self.cleaned_img_to_ann_ids = {k: v for k, v in self.coco_text.imgToAnns.items() if v and k in self.set_part}
        # filter out annotations that are illegible or have malformed bounding boxes
        self.ann_ids = list()
        self.image_ids = list()
        for ann_id_list in self.cleaned_img_to_ann_ids.values():
            annotations = self.coco_text.loadAnns(ann_id_list)
            allowed_ann_ids = list()
            allowed_image_ids = list()
            for annotation in annotations:
                if annotation['legibility'] == 'legible' and len(annotation['bbox']) == 4:
                    allowed_ann_ids.append(annotation['id'])
                    if annotation['image_id'] not in allowed_image_ids:
                        allowed_image_ids.append(annotation['image_id'])

            # if image has no annotations, skip it
            if allowed_image_ids and allowed_ann_ids:
                self.image_ids.append(allowed_image_ids)
                self.ann_ids.append(allowed_ann_ids)

        if self.data_limit:
            self.image_ids = self.image_ids[0:data_limit]
            self.ann_ids = self.ann_ids[0:data_limit]

        self.image_info = list()
        self.ann_info = list()
        for image_id_list in self.image_ids:
            info = self.coco_text.loadImgs(image_id_list)
            self.image_info.append(info)
        for ann_id_list in self.ann_ids:
            info = self.coco_text.loadAnns(ann_id_list)
            self.ann_info.append(info)

        if len(self.image_info) == len(self.ann_info):
            print('Dataset created successfully')
            self.dataset_length = len(self.image_info)
        else:
            print(f'Error: number of images and annotations do not match: {len(self.image_info)} images vs. {len(self.ann_info)} annotations')
            sys.exit(1)  # non-zero exit code to signal the failure

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_file = self.image_info[index]
        annotations = self.ann_info[index]
        image_path = os.path.join(self.image_folder_path, image_file[0]['file_name'])
        image = Image.open(image_path).convert("RGB")

        target = {'image_id': image_id[0], 'annotations': annotations}
        encoding = self.feature_extractor(images=image, annotations=target, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze() # remove batch dimension
        target = encoding["labels"][0] # remove batch dimension
        return pixel_values, target


class COCODatasetLoader(pl.LightningDataModule):
    def __init__(self, path, ann_file_name, image_folder_name, feature_extractor, batch_size, worker, collator, data_limit=None):
        super().__init__()
        self.path = path
        self.ann_file_name = ann_file_name
        self.image_folder_name = image_folder_name
        self.feature_extractor = feature_extractor
        self.batch_size = batch_size
        self.worker = worker
        self.collator = collator
        self.data_limit = data_limit
        print(f'Data Limit is set to : {self.data_limit}')

    def setup(self, stage=None):
        self.train_dataset = CocoTextDataset(self.path, self.ann_file_name, self.image_folder_name, self.feature_extractor, is_train=True, data_limit=self.data_limit)
        print(f'# of training samples: {self.train_dataset.dataset_length}')

        self.val_dataset = CocoTextDataset(self.path, self.ann_file_name, self.image_folder_name, self.feature_extractor, is_train=False, data_limit=self.data_limit)
        print(f'# of validation samples: {self.val_dataset.dataset_length}')

    def visualize_example(self, index):
        print(f'Visualize Example: {index}')
        file_name = self.train_dataset.coco_text.loadImgs(self.train_dataset.image_ids[index])[0]['file_name']
        path = os.path.join(self.train_dataset.image_folder_path, file_name)
        annotations = self.train_dataset.coco_text.loadAnns(self.train_dataset.ann_ids[index])
        print(f'{len(annotations)} boxes in this image')

        image = Image.open(path).convert("RGB")
        draw = ImageDraw.Draw(image, "RGBA")

        for annotation in annotations:
            box = annotation['bbox']
            x,y,w,h = tuple(box)
            draw.rectangle((x,y,x+w,y+h), outline='red', width=1)

        image.show()

    def get_val_coco_text_dataset(self):
        return self.val_dataset.coco_text

    def train_dataloader(self):
        # shuffle the training data every epoch; shuffle=False here can hurt convergence
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.worker, pin_memory=True, collate_fn=self.collator)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.worker, pin_memory=True, collate_fn=self.collator)
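
For reference, the sample printed above comes from a quick smoke test of the data module along these lines (a sketch; the paths are placeholders for my local setup, and feature_extractor is built as in the trainer script below):

data_module = COCODatasetLoader(path='/COCOText-v2', ann_file_name='cocotext.v2.json',
                                image_folder_name='train2014', feature_extractor=feature_extractor,
                                batch_size=8, worker=0, collator=collate_fn, data_limit=16)
data_module.setup()
pixel_values, target = data_module.train_dataset[0]
print(pixel_values.shape)  # e.g. torch.Size([3, 640, 640])
print(target)              # boxes, class_labels, image_id, area, iscrowd, orig_size, size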

Model:

class TextDetectionModel(pl.LightningModule):

    def __init__(self, lr, id2label, feature_extractor, coco_evaluator, sync):
        super().__init__()
        self.save_hyperparameters()
        self.sync_dist = sync
        self.lr = lr
        self.id2label = id2label
        self.feature_extractor = feature_extractor
        self.coco_evaluator = coco_evaluator
        self.num_classes = len(id2label)
        self.model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", num_queries=2000, encoder_layerdrop=0.2, decoder_layerdrop=0.2,
                                                            num_labels=self.num_classes, ignore_mismatched_sizes=True, return_dict=True)

    def forward(self, pixel_values, pixel_mask=None, labels=None):
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels, return_dict=True)
        return outputs.loss, outputs.loss_dict, outputs.logits, outputs.pred_boxes

    def training_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        loss = outputs[0]
        loss_dict = outputs[1]

        self.log("train_loss", loss.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
        for k,v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        loss = outputs[0]
        loss_dict = outputs[1]
        logits = outputs[2]
    #    pred_boxes = outputs[3]

        # average class probability (excluding the no-object class) over all
        # queries of the first image in the batch
        proba = logits.softmax(-1)[0, :, :-1].mean()

        # compute COCO Output for each image
    #    orig_target_sizes = torch.stack([target["orig_size"] for target in labels], dim=0)
    #    results = self.feature_extractor.post_process(outputs, orig_target_sizes) # convert outputs of model to COCO api
    #    res = {target['image_id'].item(): output for target, output in zip(labels, results)}

    # COCO eval is currently broken (the numpy .append issue described above)
    #    self.coco_evaluator.update(res)

        self.log("val_loss", loss.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
        self.log("val_bbox_proba", proba.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
        for k,v in loss_dict.items():
            self.log("val_" + k, v.item())

        return loss

    #def validation_epoch_end(self, outputs):
    #    self.coco_evaluator.synchronize_between_processes()
    #    self.coco_evaluator.accumulate()
    #    self.coco_evaluator.summarize()

    def predict_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]

        outputs = self.model(pixel_values=pixel_values)
        # access by attribute: without labels there is no loss in the output,
        # so positional indexing (outputs[2]/outputs[3]) would be shifted
        logits = outputs.logits
        pred_boxes = outputs.pred_boxes

        # class probabilities (excluding the no-object class) for the first image in the batch
        probas = logits.softmax(-1)[0, :, :-1]
        return {'probas': probas, 'pred_boxes': pred_boxes}

    def configure_optimizers(self):
        param_dicts = [
              {"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
              {
                  "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                  "lr": 1e-5, # this lr is used for backbone parameters
              },
        ]
        optimizer = AdamW(param_dicts, lr=self.lr, weight_decay=1e-4)
        scheduler = ReduceLROnPlateau(optimizer, patience=2, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        optimizer.zero_grad(set_to_none=True)
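
Once training works, my plan for inference looks roughly like this (a sketch, not final code; the image path and the 0.7 threshold are placeholders, and trained_model is the module loaded from the best checkpoint as in the trainer script below). DetrFeatureExtractor.post_process rescales the normalized (cx, cy, w, h) predictions to absolute (x0, y0, x1, y1) boxes in the original image size:

# sketch: run the trained LightningModule on a single image
extractor = DetrFeatureExtractor.from_pretrained('text_detection_model_files/transformer_model/')
image = Image.open('some_document_page.png').convert('RGB')  # placeholder path
encoding = extractor(images=image, return_tensors='pt')

trained_model.eval()
with torch.no_grad():
    outputs = trained_model.model(pixel_values=encoding['pixel_values'])

target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (w, h) -> (h, w)
results = extractor.post_process(outputs, target_sizes)[0]
keep = results['scores'] > 0.7  # confidence threshold, chosen arbitrarily
print(results['boxes'][keep])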

Trainer

import argparse
import os
import warnings
import time

import numpy as np
import onnx
import pytorch_lightning as pl
import torch
from onnxruntime.quantization import quantize_qat
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor, ModelCheckpoint)
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import DetrFeatureExtractor

from coco_tools.coco_torch_evaluator import CocoEvaluator
from dataloader import COCODatasetLoader, collate_fn
from model import TextDetectionModel

def __check_for_boolean_value(val):
    """argparse helper: parse a 'true'/'false' string into a bool"""
    return val.lower() == "true"

if __name__ == '__main__':

    warnings.filterwarnings("ignore")

    pl.seed_everything(42, workers=True)

    print('annotations file and image folder have to be in the same parent folder')

    parser = argparse.ArgumentParser(description='Text Detection Trainer')
    parser.add_argument("--path", help='path to generated images', type=str, required=False, default='/COCOText-v2') #set to true
    parser.add_argument("--ann_file_name", help='name of annotations file', type=str, required=False, default='cocotext.v2.json')
    parser.add_argument("--image_folder_name", help='name of image folder', type=str, required=False, default='train2014')
    parser.add_argument("--epochs", help='how many epochs to train the model',type=int, required=False, default=250)
    parser.add_argument("--batch_size", help='how big are a batch',type=int, required=False, default=8)
    parser.add_argument("--data_limit", help='set a fixed data limit',type=int, required=False, default=0)
    parser.add_argument("--worker", help='how many threads for the Dataloader',type=int, required=False, default=0)
    parser.add_argument("--learning_rate", help='the learning rate for the optimizer',type=float, required=False, default=1e-4)
    parser.add_argument("--gradient_clip", help='float for gradient clipping',type=float, required=False, default=0.1)
    parser.add_argument("--visualize_random_example", help='if true show an example from train set',type=__check_for_boolean_value, required=False, default=False)

    args = parser.parse_args()
    path = args.path
    ann_file_name = args.ann_file_name
    image_folder_name = args.image_folder_name
    epochs = args.epochs
    batch_size = args.batch_size
    data_limit = args.data_limit
    worker = args.worker
    learning_rate = args.learning_rate
    gradient_clip = args.gradient_clip
    visualize_random_example = args.visualize_random_example

    if data_limit == 0:
        data_limit = None

    # resource handling
    if torch.cuda.device_count() >= 1:
        batch_size = int(batch_size / torch.cuda.device_count())
        accelerator = 'ddp'
        sync = True
    else:
        accelerator = None
        sync = False

    ### Data Part
    os.makedirs('text_detection_model_files', exist_ok=True)

    feature_extractor = DetrFeatureExtractor(format="coco_detection", do_resize=False, do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225])
    feat_extractor_to_save = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", do_resize=True, size=600)
    feat_extractor_to_save.save_pretrained('text_detection_model_files/transformer_model/')
    print('feature extractor saved successfully')


    data_module = COCODatasetLoader(path=path,
                                    ann_file_name=ann_file_name,
                                    image_folder_name=image_folder_name,
                                    feature_extractor=feature_extractor,
                                    batch_size=batch_size,
                                    worker=worker,
                                    collator=collate_fn,
                                    data_limit=data_limit)
    data_module.setup()
    if visualize_random_example:
        index = np.random.choice(len(data_module.train_dataset))
        data_module.visualize_example(index)

    train = data_module.train_dataloader()
    val = data_module.val_dataloader()
    coco_val_dataset = data_module.get_val_coco_text_dataset()
    coco_evaluator = CocoEvaluator(coco_val_dataset, ['bbox'])
    print('Coco Evaluator created')

    ### Model Part
    id2label = {0: 'Text'} # we have only one class to detect: Text
    text_detection_model = TextDetectionModel(lr=learning_rate, id2label=id2label, feature_extractor=feature_extractor, coco_evaluator=coco_evaluator, sync=sync)

    ### Callback Part

    checkpoint_callback = ModelCheckpoint(
        dirpath="text_detection_model_files/checkpoints",
        filename="best-checkpoint",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min"
    )

    logger = TensorBoardLogger(save_dir="text_detection_model_files/Lightning_logs", name="Text_Detection")

    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.001,
        patience=15,
        check_finite=True,
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    ### Training Part

    trainer = pl.Trainer(logger=logger,
                         weights_summary="full",
                         benchmark=True,
                         # set move_metrics_to_cpu=True only if GPU memory overflows (training gets much slower)
                         move_metrics_to_cpu=False,
                         val_check_interval=0.5,
                         gradient_clip_val=gradient_clip,  # set to 0.5 to avoid exploding gradients
                         stochastic_weight_avg=True,
                         callbacks=[
                             checkpoint_callback,
                             early_stopping_callback,
                             lr_monitor
                             ],
                         max_epochs=epochs,
                         gpus=torch.cuda.device_count(),
                         accelerator=accelerator,
                         precision=32, # don't change for this model
                         accumulate_grad_batches=1,  # optimizer step after every n batches  -> better gpu mem usage / model specific
                         progress_bar_refresh_rate=20,
                        # profiler='pytorch', # only for debug
                         )

    trainer.fit(text_detection_model, train, val)

    time.sleep(2) # short delay so the checkpoint file is fully written before loading it
    trained_model = text_detection_model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    trained_model.eval()
    trained_model.freeze()

    ### Saving Part

    # ----------------------------------
    # PyTorch Model - full
    # ----------------------------------
    try:
        torch.save(trained_model, "text_detection_model_files/torch_text_detection_model.pt")
        print('Torch model saved successfully')
    except Exception as e:
        print('Cannot export as PyTorch Format -- Error : ' + str(e))

    # ----------------------------------
    # PyTorch Model - state dict
    # ----------------------------------
    try:
        torch.save(trained_model.state_dict(), "text_detection_model_files/torch_text_detection_model_state_dict.pt")
        print('Torch model state dict saved successfully')
    except Exception as e:
        print('Cannot export as PyTorch Format with state dict -- Error : ' + str(e))

    # ----------------------------------
    # onnx
    # ----------------------------------
    try:
        input_batch = next(iter(val))
        input_sample = {
            "pixel_values": input_batch["pixel_values"][0].unsqueeze(0),
        }
        values = input_sample['pixel_values']
        file_path = "text_detection_model_files/torch_text_detection_model.onnx"
        torch.onnx.export(trained_model, values, file_path,
                              input_names=['pixel_values'],
                              output_names=['logits', 'pred_boxes'],
                              dynamic_axes={'pixel_values': {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'},
                                            'logits': {0: 'batch_size'}, 'pred_boxes': {0: 'batch_size'}},
                              export_params=True, opset_version=11,
                              enable_onnx_checker=True, verbose=False)
        print('Onnx model saved successfully')
        print('Start model quantization')
        model_quant = "text_detection_model_files/torch_text_detection_model.quant.onnx"
        quantized_model = quantize_qat(file_path, model_quant)
        print('Quantization successful')
    except Exception as e:
        print('Cannot export as ONNX Format -- Error : ' + str(e))

    # Predictions
    model = text_detection_model.load_from_checkpoint(checkpoint_path=trainer.checkpoint_callback.best_model_path)
    preds = trainer.predict(model, val, return_predictions=True)
    print(preds)
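
To double-check the ONNX export, I load it back with onnxruntime, roughly like this (a sketch, run separately from the training script):

import numpy as np
import onnxruntime as ort

# quick sanity check of the exported ONNX model (assumes the export succeeded)
session = ort.InferenceSession("text_detection_model_files/torch_text_detection_model.onnx")
dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)  # NCHW dummy input
logits, pred_boxes = session.run(['logits', 'pred_boxes'], {'pixel_values': dummy})
print(logits.shape, pred_boxes.shape)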

@nielsr do you have any ideas or recommendations? ^^