Segmentation of COVID-19 Chest CT Scans using TransUNet
Transformers have shown great performance in the field of Image Segmentation. It would be amazing to train them using JAX/Flax. The goal of this project is automatic COVID-19 lesion segmentation from chest CT.
2. Language
The model will be trained in English.
3. Model
The TransUNet model will be used for segmentation.
4. Datasets
Possible links to publicly available datasets include:
I’ve written many UNets and data augmentations in both TF and PyTorch as part of my daily job and could bring that experience to the team, helping with the model or data augmentation as needed.
On the other hand, I haven’t seen much JAX, dealt with many medical datasets, or combined transformers with UNet architectures, all of which I’d love to correct. The project outcome is also clearly beneficial, although one should point out that recent analyses have found much COVID-related CV research quite problematic (Nature article).
@dom-miketa same here, I’ve also worked with TensorFlow and PyTorch for segmentation. As JAX is getting popular nowadays, it would be amazing to learn it for segmentation tasks.
Yes, you’re right. I’ve read that paper, but what confused me is whether they were actually able to reproduce the results, because not all datasets are publicly available. Apart from that, TransUNet can be used for any medical image data, so we can take advantage of it either way.
Btw, I’m in the Bangladesh Standard Time zone (GMT+6). Looking forward to working with you.
Lots of interest here - awesome! Let’s define this project. I think this is one of the more difficult projects, since I’m not sure whether there is an open-source implementation of TransUNet in JAX. @awsaf49, is it maybe possible to use another model instead? ViT maybe?
Just a thought, @patrickvonplaten and @awsaf49: in case we need pixel-level classification, I guess ViT won’t be that helpful, unless we are okay with patch-level segmentation. I saw that a BigBird implementation is available in JAX, so maybe we could go with that, since we could flatten the images to pixel level? The images seem to be 512 x 512. BigBird might need a lot of samples though, since it would need to be trained from scratch…
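To put rough numbers on the patch-level vs pixel-level trade-off, here is a small sketch in plain Python. The 512 x 512 image size comes from the post above; the 16 x 16 patch size is only an assumption (a common ViT setting, not something decided in this thread), and the helper names `num_tokens` and `upsample_patch_labels` are made up for illustration.

```python
IMAGE_SIZE = 512   # from the thread: images seem to be 512 x 512
PATCH_SIZE = 16    # assumption: a common ViT patch size, not fixed in this thread


def num_tokens(image_size: int, patch_size: int = 1) -> int:
    """Sequence length when a square image is cut into square patches.

    patch_size=1 corresponds to flattening the image to pixel level.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


def upsample_patch_labels(grid, patch_size):
    """Nearest-neighbour expansion of a patch-level label grid to pixel level.

    Every pixel simply inherits the label of the patch it falls in, which is
    what "patch-level segmentation" amounts to without a learned decoder.
    """
    rows, cols = len(grid), len(grid[0])
    return [
        [grid[r // patch_size][c // patch_size] for c in range(cols * patch_size)]
        for r in range(rows * patch_size)
    ]


pixel_tokens = num_tokens(IMAGE_SIZE)              # one token per pixel
patch_tokens = num_tokens(IMAGE_SIZE, PATCH_SIZE)  # one token per patch
print(pixel_tokens, patch_tokens)

# Dense self-attention compares every token with every other token,
# so the number of attention scores grows with the square of the length.
print(pixel_tokens ** 2, patch_tokens ** 2)

# Toy example: a 2 x 2 patch-level prediction expanded to 4 x 4 pixels.
coarse = [[0, 1], [1, 0]]
mask = upsample_patch_labels(coarse, 2)
for row in mask:
    print(row)
```

The pixel-level sequence is far longer than what dense attention handles comfortably, which is presumably why a sparse-attention model like BigBird comes up here, while the patch-level route stays short but gives blocky masks unless a decoder refines them.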
@patrickvonplaten I think the main purpose of this project is to do pixel-level classification, i.e. segmentation, with a transformer. If we can do segmentation with ViT, then I think we should use ViT too. As far as I know, ViT is for classification, not segmentation. If there is no implementation of TransUNet in JAX, perhaps we can work on it? It would be amazing if we could implement it in this project…