Hi,
I'm currently running some experiments on text detection with a transformer-based model. Does anyone have experience with this or any recommendations?
My idea is to train DetrForObjectDetection on the COCOText-v2 dataset.
I have tested several setups:
- pretrained facebook/detr-resnet-50 with num_queries=2000 (a reasonable value for an A4 document page)
- from scratch with an efficientnet_b0 backbone from timm, with backbone lr 0.001 and lr 0.01
In all cases, the train and validation loss get stuck at ~1.7 after ~35 epochs (with 2 validation steps per epoch).
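For reference, the from-scratch variant can be built roughly like this. This is only a minimal, untested sketch; it assumes a transformers version whose DetrConfig accepts a timm backbone name via `backbone` together with a `use_pretrained_backbone` flag (check your installed version), and the learning rates are then set in the optimizer:

from transformers import DetrConfig, DetrForObjectDetection

# sketch: DETR from scratch with a timm efficientnet_b0 backbone
config = DetrConfig(
    backbone="efficientnet_b0",     # timm model name
    use_pretrained_backbone=False,  # random init -> train from scratch
    num_labels=1,                   # single class: Text
    num_queries=2000,
)
model = DetrForObjectDetection(config)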
Another problem I have faced is with the COCO evaluator: at the validation step it fails with an error along the lines of 'numpy.ndarray' object has no attribute 'append', raised at this line in the evaluator:
self.eval_imgs[iou_type].append(eval_imgs)
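If the evaluator is the torchvision-derived CocoEvaluator, my guess (not verified) is that synchronize_between_processes() replaces the internal eval_imgs lists with concatenated numpy arrays, so any later update() call hits .append on an ndarray. A possible workaround is to give every validation epoch a fresh evaluator; an untested sketch, assuming the ground-truth COCO_Text object is stored on the LightningModule as self.coco_gt:

def on_validation_epoch_start(self):
    # re-create the evaluator so its internal lists have not yet
    # been converted to numpy arrays by a previous synchronize
    self.coco_evaluator = CocoEvaluator(self.coco_gt, ['bbox'])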
One sample from my train DataLoader looks like this:
# pixel_values 1 example
torch.Size([3, 640, 640])
# target for this example
{'boxes': tensor([[0.0810, 0.8323, 0.1621, 0.1356],
[0.3031, 0.3070, 0.0367, 0.0088],
[0.5304, 0.3418, 0.0349, 0.0102]]), 'class_labels': tensor([0, 0, 0]), 'image_id': tensor([367969]), 'area': tensor([5295.0200, 103.8200, 105.6000]), 'iscrowd': tensor([0, 0, 0]), 'orig_size': tensor([640, 556]), 'size': tensor([640, 556])}
So the data coming out of the DataLoader looks OK: the boxes are in the normalized (center_x, center_y, width, height) format that DETR expects.
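A quick sanity check along those lines (a sketch, assuming a CocoTextDataset instance named train_dataset as defined below; every coordinate should lie in [0, 1]):

pixel_values, target = train_dataset[0]
assert pixel_values.ndim == 3  # (C, H, W) image tensor
assert (target['boxes'] >= 0).all() and (target['boxes'] <= 1).all()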
Some more code below. The COCO tools are adapted from the COCO-Text API and the PyTorch COCO evaluation reference code.
Dataloader:
import os
import sys

import pytorch_lightning as pl
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
from transformers import DetrFeatureExtractor
from coco_tools.coco_text import COCO_Text  # adjust to wherever COCO_Text lives

padding_extractor = DetrFeatureExtractor()  # created once; only used for padding

def collate_fn(batch):
    """Pad all images of the batch to a common size and create the pixel mask."""
    pixel_values = [item[0] for item in batch]
    encoding = padding_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    return {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }
class CocoTextDataset(Dataset):
"""MSCOCO Text V2 Dataset
"""
def __init__(self, path, ann_file_name, image_folder_name, feature_extractor, is_train=True, data_limit=None):
self.path = path
self.annotation_path = os.path.join(path, ann_file_name)
self.image_folder_path = os.path.join(path, image_folder_name)
self.feature_extractor = feature_extractor
self.data_limit = data_limit
self.dataset_length = 0
self.coco_text = COCO_Text(annotation_file=self.annotation_path)
if is_train:
print('Load Training Data')
self.set_part = self.coco_text.train
else:
print('Load Validation Data')
self.set_part = self.coco_text.val
# create sets for train and validation
self.cleaned_img_to_ann_ids = {k:v for k,v in self.coco_text.imgToAnns.items() if v and k in self.set_part}
        # filter out images and annotations that are illegible or have malformed bounding boxes
self.ann_ids = list()
self.image_ids = list()
        for ann_id_list in self.cleaned_img_to_ann_ids.values():  # one list of annotation ids per image
            annotations = self.coco_text.loadAnns(ann_id_list)
allowed_ann_ids = list()
allowed_image_ids = list()
for annotation in annotations:
if annotation['legibility'] == 'legible' and len(annotation['bbox']) == 4:
allowed_ann_ids.append(annotation['id'])
if annotation['image_id'] not in allowed_image_ids:
allowed_image_ids.append(annotation['image_id'])
# if image has no annotations, skip it
if allowed_image_ids and allowed_ann_ids:
self.image_ids.append(allowed_image_ids)
self.ann_ids.append(allowed_ann_ids)
        if self.data_limit:
            self.image_ids = self.image_ids[:self.data_limit]
            self.ann_ids = self.ann_ids[:self.data_limit]
self.image_info = list()
self.ann_info = list()
        for image_id in self.image_ids:
            self.image_info.append(self.coco_text.loadImgs(image_id))
        for ann_id in self.ann_ids:
            self.ann_info.append(self.coco_text.loadAnns(ann_id))
        if len(self.image_info) == len(self.ann_info):
            print('Dataset created successfully')
            self.dataset_length = len(self.image_info)
        else:
            print(f'Error: number of images and annotations do not match: {len(self.image_info)} images and {len(self.ann_info)} annotations')
            sys.exit(1)  # exit with a non-zero code on error
def __len__(self):
return self.dataset_length
def __getitem__(self, index):
image_id = self.image_ids[index]
image_file = self.image_info[index]
annotations = self.ann_info[index]
image_path = os.path.join(self.image_folder_path, image_file[0]['file_name'])
image = Image.open(image_path).convert("RGB")
target = {'image_id': image_id[0], 'annotations': annotations}
encoding = self.feature_extractor(images=image, annotations=target, return_tensors="pt")
pixel_values = encoding["pixel_values"].squeeze() # remove batch dimension
target = encoding["labels"][0] # remove batch dimension
return pixel_values, target
class COCODatasetLoader(pl.LightningDataModule):
def __init__(self, path, ann_file_name, image_folder_name, feature_extractor, batch_size, worker, collator, data_limit=None):
super().__init__()
self.path = path
self.ann_file_name = ann_file_name
self.image_folder_name = image_folder_name
self.feature_extractor = feature_extractor
self.batch_size = batch_size
self.worker = worker
self.collator = collator
self.data_limit = data_limit
print(f'Data Limit is set to : {self.data_limit}')
def setup(self, stage=None):
self.train_dataset = CocoTextDataset(self.path, self.ann_file_name, self.image_folder_name, self.feature_extractor, is_train=True, data_limit=self.data_limit)
print(f'# of training samples: {self.train_dataset.dataset_length}')
self.val_dataset = CocoTextDataset(self.path, self.ann_file_name, self.image_folder_name, self.feature_extractor, is_train=False, data_limit=self.data_limit)
print(f'# of validation samples: {self.val_dataset.dataset_length}')
def visualize_example(self, index):
print(f'Visualize Example: {index}')
file_name = self.train_dataset.coco_text.loadImgs(self.train_dataset.image_ids[index])[0]['file_name']
path = os.path.join(self.train_dataset.image_folder_path, file_name)
annotations = self.train_dataset.coco_text.loadAnns(self.train_dataset.ann_ids[index])
        print(f'{len(annotations)} boxes found in image')
image = Image.open(path).convert("RGB")
draw = ImageDraw.Draw(image, "RGBA")
for annotation in annotations:
box = annotation['bbox']
x,y,w,h = tuple(box)
draw.rectangle((x,y,x+w,y+h), outline='red', width=1)
image.show()
def get_val_coco_text_dataset(self):
return self.val_dataset.coco_text
    def train_dataloader(self):
        # NOTE: shuffle is disabled here; for training you usually want shuffle=True
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.worker, pin_memory=True, collate_fn=self.collator)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.worker, pin_memory=True, collate_fn=self.collator)
Model:
import torch
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import DetrForObjectDetection

class TextDetectionModel(pl.LightningModule):
def __init__(self, lr, id2label, feature_extractor, coco_evaluator, sync):
super().__init__()
self.save_hyperparameters()
self.sync_dist = sync
self.lr = lr
self.id2label = id2label
self.feature_extractor = feature_extractor
self.coco_evaluator = coco_evaluator
self.num_classes = len(id2label)
self.model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", num_queries=2000, encoder_layerdrop=0.2, decoder_layerdrop=0.2,
num_labels=self.num_classes, ignore_mismatched_sizes=True, return_dict=True)
def forward(self, pixel_values, pixel_mask=None, labels=None):
outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels, return_dict=True)
return outputs.loss, outputs.loss_dict, outputs.logits, outputs.pred_boxes
def training_step(self, batch, batch_idx):
pixel_values = batch["pixel_values"]
pixel_mask = batch["pixel_mask"]
labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]
outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
loss = outputs[0]
loss_dict = outputs[1]
self.log("train_loss", loss.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
for k,v in loss_dict.items():
self.log("train_" + k, v.item())
return loss
def validation_step(self, batch, batch_idx):
pixel_values = batch["pixel_values"]
pixel_mask = batch["pixel_mask"]
labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]
outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
loss = outputs[0]
loss_dict = outputs[1]
logits = outputs[2]
# pred_boxes = outputs[3]
        # mean class probability over the queries of the first image in the batch
        # (softmax over classes, dropping the final no-object logit)
        proba = logits.softmax(-1)[0, :, :-1].mean()
# compute COCO Output for each image
# orig_target_sizes = torch.stack([target["orig_size"] for target in labels], dim=0)
# results = self.feature_extractor.post_process(outputs, orig_target_sizes) # convert outputs of model to COCO api
# res = {target['image_id'].item(): output for target, output in zip(labels, results)}
# Coco Eval is broken currently
# self.coco_evaluator.update(res)
self.log("val_loss", loss.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
self.log("val_bbox_proba", proba.detach(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
for k,v in loss_dict.items():
self.log("val_" + k, v.item())
return loss
#def validation_epoch_end(self, outputs):
# self.coco_evaluator.synchronize_between_processes()
# self.coco_evaluator.accumulate()
# self.coco_evaluator.summarize()
    def predict_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        outputs = self.model(pixel_values=pixel_values)
        # access outputs by name: without labels, loss and loss_dict are None and
        # are skipped by positional indexing into the ModelOutput, so outputs[2]
        # would no longer be the logits
        logits = outputs.logits
        pred_boxes = outputs.pred_boxes
        probas = logits.softmax(-1)[0, :, :-1]
        return {'probas': probas, 'pred_boxes': pred_boxes}
def configure_optimizers(self):
param_dicts = [
{"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
"lr": 1e-5, # this lr is used for backbone parameters
},
]
optimizer = AdamW(param_dicts, lr=self.lr, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, patience=2, verbose=True)
return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
optimizer.zero_grad(set_to_none=True)
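For completeness, a sketch of how the output of predict_step could be filtered down to confident boxes (the 0.7 threshold is arbitrary):

preds = model.predict_step(batch, 0)
# probas: [num_queries, num_classes]; pred_boxes: normalized (cx, cy, w, h)
keep = preds['probas'].max(-1).values > 0.7   # arbitrary confidence threshold
boxes = preds['pred_boxes'][0, keep]          # confident boxes of the first image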
Trainer:
import argparse
import os
import warnings
import time
import numpy as np
import onnx
import pytorch_lightning as pl
import torch
from onnxruntime.quantization import quantize_qat
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor, ModelCheckpoint)
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import DetrFeatureExtractor
from coco_tools.coco_torch_evaluator import CocoEvaluator
from dataloader import COCODatasetLoader, collate_fn
from model import TextDetectionModel
def __check_for_boolean_value(val):
    """argparse helper: interpret the string 'true' (case-insensitive) as True"""
    return val.lower() == "true"
if __name__ == '__main__':
warnings.filterwarnings("ignore")
pl.seed_everything(42, workers=True)
    print('the annotations file and the image folder have to be in the same parent folder')
parser = argparse.ArgumentParser(description='Text Detection Trainer')
parser.add_argument("--path", help='path to generated images', type=str, required=False, default='/COCOText-v2') #set to true
parser.add_argument("--ann_file_name", help='name of annotations file', type=str, required=False, default='cocotext.v2.json')
parser.add_argument("--image_folder_name", help='name of image folder', type=str, required=False, default='train2014')
parser.add_argument("--epochs", help='how many epochs to train the model',type=int, required=False, default=250)
parser.add_argument("--batch_size", help='how big are a batch',type=int, required=False, default=8)
parser.add_argument("--data_limit", help='set a fixed data limit',type=int, required=False, default=0)
parser.add_argument("--worker", help='how many threads for the Dataloader',type=int, required=False, default=0)
parser.add_argument("--learning_rate", help='the learning rate for the optimizer',type=float, required=False, default=1e-4)
parser.add_argument("--gradient_clip", help='float for gradient clipping',type=float, required=False, default=0.1)
parser.add_argument("--visualize_random_example", help='if true show an example from train set',type=__check_for_boolean_value, required=False, default=False)
args = parser.parse_args()
path = args.path
ann_file_name = args.ann_file_name
image_folder_name = args.image_folder_name
epochs = args.epochs
batch_size = args.batch_size
data_limit = args.data_limit
worker = args.worker
learning_rate = args.learning_rate
gradient_clip = args.gradient_clip
visualize_random_example = args.visualize_random_example
if data_limit == 0:
data_limit = None
# resource handling
if torch.cuda.device_count() >= 1:
batch_size = int(batch_size / torch.cuda.device_count())
accelerator = 'ddp'
sync = True
else:
accelerator = None
sync = False
### Data Part
os.makedirs('text_detection_model_files', exist_ok=True)
    feature_extractor = DetrFeatureExtractor(format="coco_detection", do_resize=False, do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225])
    # NOTE: the saved extractor resizes (size=600) while the training extractor does not resize
    feat_extractor_to_save = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", do_resize=True, size=600)
    feat_extractor_to_save.save_pretrained('text_detection_model_files/transformer_model/')
    print('feature extractor saved successfully')
data_module = COCODatasetLoader(path=path,
ann_file_name=ann_file_name,
image_folder_name=image_folder_name,
feature_extractor=feature_extractor,
batch_size=batch_size,
worker=worker,
collator=collate_fn,
data_limit=data_limit)
data_module.setup()
if visualize_random_example:
index = np.random.choice(len(data_module.train_dataset))
data_module.visualize_example(index)
train = data_module.train_dataloader()
val = data_module.val_dataloader()
coco_val_dataset = data_module.get_val_coco_text_dataset()
coco_evaluator = CocoEvaluator(coco_val_dataset, ['bbox'])
print('Coco Evaluator created')
### Model Part
id2label = {0: 'Text'} # we have only one class to detect: Text
text_detection_model = TextDetectionModel(lr=learning_rate, id2label=id2label, feature_extractor=feature_extractor, coco_evaluator=coco_evaluator, sync=sync)
### Callback Part
checkpoint_callback = ModelCheckpoint(
dirpath="text_detection_model_files/checkpoints",
filename="best-checkpoint",
save_top_k=1,
verbose=True,
monitor="val_loss",
mode="min"
)
logger = TensorBoardLogger(save_dir="text_detection_model_files/Lightning_logs", name="Text_Detection")
early_stopping_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.001,
patience=15,
check_finite=True,
verbose=True
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
### Training Part
trainer = pl.Trainer(logger=logger,
weights_summary="full",
                         benchmark=True,
                         # set move_metrics_to_cpu=True only if GPU memory runs out; it makes training much slower
                         move_metrics_to_cpu=False,
val_check_interval=0.5,
gradient_clip_val=gradient_clip, # set to 0.5 to avoid exploding gradients
stochastic_weight_avg=True,
callbacks=[
checkpoint_callback,
early_stopping_callback,
lr_monitor
],
max_epochs=epochs,
gpus=torch.cuda.device_count(),
accelerator=accelerator,
                         precision=32, # don't change for this model
accumulate_grad_batches=1, # optimizer step after every n batches -> better gpu mem usage / model specific
progress_bar_refresh_rate=20,
# profiler='pytorch', # only for debug
)
trainer.fit(text_detection_model, train, val)
time.sleep(2) # short delay
    trained_model = TextDetectionModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trained_model.eval()
trained_model.freeze()
### Saving Part
# ----------------------------------
# PyTorch Model - full
# ----------------------------------
try:
torch.save(trained_model, "text_detection_model_files/torch_text_detection_model.pt")
        print('Torch model saved successfully')
except Exception as e:
print('Cannot export as PyTorch Format -- Error : ' + str(e))
# ----------------------------------
# PyTorch Model - state dict
# ----------------------------------
try:
torch.save(trained_model.state_dict(), "text_detection_model_files/torch_text_detection_model_state_dict.pt")
        print('Torch model state dict saved successfully')
except Exception as e:
print('Cannot export as PyTorch Format with state dict -- Error : ' + str(e))
# ----------------------------------
# onnx
# ----------------------------------
try:
input_batch = next(iter(val))
input_sample = {
"pixel_values": input_batch["pixel_values"][0].unsqueeze(0),
}
values = input_sample['pixel_values']
        file_path = "text_detection_model_files/torch_text_detection_model.onnx"
        # NOTE: forward() returns (loss, loss_dict, logits, pred_boxes); without
        # labels the first two are None, which can break tracing -- exporting the
        # bare trained_model.model may be more robust
        torch.onnx.export(trained_model, values, file_path,
                          input_names=['pixel_values'],
                          output_names=['logits', 'pred_boxes'],
                          dynamic_axes={'pixel_values': {0: 'batch_size', 1: 'channels', 2: 'width', 3: 'height'},
                                        'logits': {0: 'batch_size'}, 'pred_boxes': {0: 'batch_size'}},
                          export_params=True, opset_version=11,
                          enable_onnx_checker=True, verbose=False)
        print('Onnx model saved successfully')
        print('Start model quantization')
        model_quant = "text_detection_model_files/torch_text_detection_model.quant.onnx"
        quantized_model = quantize_qat(file_path, model_quant)
        print('Quantization successful')
except Exception as e:
print('Cannot export as ONNX Format -- Error : ' + str(e))
# Predictions
    model = TextDetectionModel.load_from_checkpoint(checkpoint_path=trainer.checkpoint_callback.best_model_path)
preds = trainer.predict(model, val, return_predictions=True)
print(preds)
@nielsr do you have any ideas or recommendations? ^^