Segmentation of COVID-19 Chest CT Scans using TransUNet

Transformers have shown strong performance in image segmentation, and it would be great to train them using JAX/Flax. The goal of this project is automatic segmentation of COVID-19 lesions in chest CT scans.

2. Language

The model will be trained in English.

3. Model

The TransUNet model will be used for segmentation.

4. Datasets

Possible links to publicly available datasets include:

5. Training scripts

Are there publicly available training scripts that can be used/tweaked for the project?
Answer: Unfortunately, no.

6. (Optional) Challenges

  • Build the TransUNet model in JAX/Flax.
  • Implement data augmentation for segmentation.
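
For the augmentation item, the main pitfall in segmentation is that every geometric transform must be applied identically to the CT slice and its mask, while intensity transforms should touch the image only. Below is a minimal sketch in plain JAX, assuming square slices and a hypothetical `augment_pair` helper. Flips and 90° rotations are used because they need no interpolation, so mask labels stay exact. The noise level is an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp


def augment_pair(key, image, mask, noise_std=0.01):
    """Randomly augment a CT slice and its segmentation mask together.

    Assumes square inputs: image (H, H, C) float, mask (H, H) int labels.
    Geometric ops (flip, 90-degree rotation) use the SAME random draw for
    both arrays so lesions stay aligned; noise is added to the image only,
    since mask values are class labels.
    """
    k_flip, k_rot, k_noise = jax.random.split(key, 3)

    # Horizontal flip with probability 0.5, shared by image and mask.
    flip = jax.random.bernoulli(k_flip, 0.5)
    image = jnp.where(flip, image[:, ::-1, :], image)
    mask = jnp.where(flip, mask[:, ::-1], mask)

    # Rotate both by k * 90 degrees. lax.switch needs every branch to return
    # the same shapes, hence the square-input assumption.
    def make_rot(k):
        return lambda im, m: (jnp.rot90(im, k, axes=(0, 1)),
                              jnp.rot90(m, k, axes=(0, 1)))

    k = jax.random.randint(k_rot, (), 0, 4)
    image, mask = jax.lax.switch(k, [make_rot(i) for i in range(4)], image, mask)

    # Additive Gaussian noise on the image only; the mask is left untouched.
    image = image + noise_std * jax.random.normal(k_noise, image.shape, image.dtype)
    return image, mask
```

The function is jit-compatible, and `jax.vmap(augment_pair)(jax.random.split(key, batch_size), images, masks)` would give each slice in a batch its own independent draw.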

7. (Optional) Desired project outcome

  • This project will help clinicians segment COVID-19 lesions, which in turn can help them identify potential COVID-19 cases.
  • TransUNet can later be used for other segmentation tasks.

I’d love to join this project!

I’ve written many UNets and data augmentations in both TF and PyTorch as part of my daily job and could bring that experience to the team, helping with the model or data augmentation as needed.

On the other hand, I haven’t used much JAX, worked with many medical datasets, or combined transformers with UNet architectures, all of which I’d love to change. The project outcome is also clearly beneficial, although it’s worth pointing out that recent analyses have found much COVID-related computer vision research quite problematic (Nature article).

I live in the London timezone (BST, GMT+1).

@dom-miketa Same here, I’ve also worked with TensorFlow and PyTorch for segmentation. As JAX is becoming popular, it would be great to learn it for segmentation tasks.
Yes, you’re right. I’ve read that paper, but what confused me is whether they were actually able to reproduce the results, since not all datasets are publicly available. Apart from that, TransUNet can be used on any medical imaging data, so we can take advantage of it either way.
Btw, I live in Bangladesh Standard Time (GMT+6). Looking forward to working with you.

I would love to be part of this project.

I would love to be a part of this project!

This seems really interesting. Would love to be part of the project.

I would love to join this project.

Hey, I’ve also been doing a lot of deep learning in the healthcare space and would love to join this if there are any spots left!

Lots of interest here, awesome! Let’s define this project 🙂 I think this is one of the more difficult projects, since I’m not sure whether there is an open-source implementation of TransUNet in JAX. @awsaf49, do you know of one? Is it maybe possible to use another model instead, ViT maybe?

Just a thought, @patrickvonplaten and @awsaf49: if we need pixel-level classification, I guess ViT won’t be that helpful unless we are okay with patch-level segmentation. I saw that a BigBird implementation is available in JAX; maybe we could go with that, since we could flatten the images to the pixel level? The images seem to be 512 × 512. BigBird might need a lot of samples, though, since it would need to be trained from scratch…
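
To put numbers on the sequence-length concern: a minimal JAX sketch, assuming single-channel 512 × 512 slices and ViT-style 16 × 16 patches (the patch size is an assumption, and `patchify`/`unpatchify` are hypothetical helpers). Flattening to pixels gives 262,144 tokens versus 1,024 patch tokens, and dense self-attention scales quadratically with sequence length, which is the cost that sparse-attention models such as BigBird aim to reduce. `unpatchify` shows that per-patch outputs can be laid back onto the image grid, which a segmentation decoder would then upsample to pixel resolution.

```python
import jax.numpy as jnp


def patchify(image, patch=16):
    """Split an (H, W, C) image into a (num_patches, patch*patch*C) token matrix."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by patch"
    x = image.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (H/p, W/p, p, p, C): one row per patch
    return x.reshape(-1, patch * patch * c)


def unpatchify(tokens, h, w, c, patch=16):
    """Inverse of patchify: lay per-patch vectors back onto the (H, W, C) grid."""
    x = tokens.reshape(h // patch, w // patch, patch, patch, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(h, w, c)


ct_slice = jnp.zeros((512, 512, 1))  # one single-channel 512 x 512 slice (assumed)
tokens = patchify(ct_slice)
print(tokens.shape)                             # (1024, 256)
print(512 * 512)                                # 262144 (one token per pixel)
print(tokens.shape[0] ** 2, (512 * 512) ** 2)   # 1048576 68719476736 (attention entries per head)
```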

@patrickvonplaten I think the main purpose of this project is to do pixel-level classification, i.e. segmentation, with a transformer. If we can do segmentation with ViT, then I think we should use ViT too, but as far as I know ViT is designed for classification, not segmentation. If there is no implementation of TransUNet in JAX, perhaps we can work on it? It would be amazing if we could implement it in this project…

I don’t think there’s a problem, though: TransUNet is ‘just’ a UNet with a ViT encoder, and the dataset is publicly available. 🙂

Edit: That said, I’ve decided to join another project. I’m a bit too comfortable with segmentation and would like to try something new!