I want to fine-tune DiT for object detection (text and diagram detection only) on my own dataset. I have been searching the web for quite some time but could not find anything on fine-tuning a Transformer backbone for object detection.
- This GitHub issue on using a custom backbone with DETR describes how to change the backbone. The author says that **you can use ANY model from the `timm` library**, and almost 890 models are available there, but unfortunately `DiT` is not one of them.
- `DiT` is also available as a Hugging Face model, `microsoft/dit-large`, and supports feature extraction via `BeitFeatureExtractor.from_pretrained("microsoft/dit-large")`, so I think it could be used as a backbone, but I found nothing on this either.
I tried adapting the code on how to train DETR on custom data by replacing the feature extractor in Cell 8:
# feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/dit-large")
but when I ran the code in Cell 11,
from torch.utils.data import DataLoader

def collate_fn(batch):
    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    batch = {}
    batch['pixel_values'] = encoding['pixel_values']
    batch['pixel_mask'] = encoding['pixel_mask']
    batch['labels'] = labels
    return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=2)
batch = next(iter(train_dataloader))
it gave me the following error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-11-446d81c845dd> in <module>
13 train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4, shuffle=True)
14 val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=2)
---> 15 batch = next(iter(train_dataloader))
5 frames
/usr/local/lib/python3.7/dist-packages/transformers/feature_extraction_utils.py in __getitem__(self, item)
85 """
86 if isinstance(item, str):
---> 87 return self.data[item]
88 else:
89 raise KeyError("Indexing with integers is not available when using Python based feature extractors")
KeyError: 'labels'
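
The last visible frame of the traceback is a plain dictionary lookup, `self.data[item]`, failing on the key `'labels'`. Below is a minimal sketch of that mechanism. It rests on two assumptions: the five hidden frames are not shown, so it is only a guess that the dataset's `__getitem__` indexes the feature extractor's output with `"labels"`; and it assumes that a classification-style extractor such as `BeitFeatureExtractor`, called without any targets, returns `pixel_values` only. `FakeBatchFeature` and `fake_classification_extractor` are hypothetical stand-ins written for illustration, not the transformers classes.

```python
class FakeBatchFeature:
    """Hypothetical stand-in that mimics the __getitem__ shown in the traceback."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, item):
        if isinstance(item, str):
            # Plain dict lookup: raises KeyError when the key is absent.
            return self.data[item]
        raise KeyError(
            "Indexing with integers is not available when using Python based feature extractors"
        )


def fake_classification_extractor(image):
    # Assumption: a classification-style extractor returns pixel values only,
    # with no detection targets ("labels") attached.
    return FakeBatchFeature({"pixel_values": [image]})


def get_labels_or_error(encoding):
    """Return encoding['labels'], or the KeyError message if the key is missing."""
    try:
        return encoding["labels"]
    except KeyError as err:
        return f"KeyError: {err}"


encoding = fake_classification_extractor([[0.0, 1.0], [1.0, 0.0]])
result = get_labels_or_error(encoding)
print(result)
```

If this guess is right, the failure would come from the dataset code that consumes the extractor's output rather than from `collate_fn` itself, which only runs once a batch of items has been fetched.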