Added you guys!
I’d love to help if there are still slots available.
I would like to join this project, please add me!
Hi guys, I did some research in the past few days and here is what I learned.
I hope it helps you better understand how the model works.
Feel free to comment/correct me or suggest other methodologies for our model.
- It creates a correlation between images and text by mapping both to similar embeddings.
- It is trained by maximizing the dot product between the embeddings of each matching image/text pair and minimizing it for non-matching pairs (this approach is also called contrastive learning).
- There is an HF transformers script (look in particular at the last section, “FLAX_CLIP_MODEL_DOCSTRING”, for interesting sample code).
- There is also a pretrained JAX model, “openai/clip-vit-base-patch32”, which we may be able to use directly (needed for inference only; see the “Our model” section below).
- I think the HF team is in the process of adding a fine-tuning script for CLIP (see TODO), but I’m not sure we will need to fine-tune it.
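To make the contrastive objective above concrete, here is a minimal numpy sketch of the symmetric cross-entropy loss; the 0.07 temperature follows the CLIP paper, but this is an illustration, not the HF implementation:

```python
import numpy as np

def logsumexp(x, axis):
    # Numerically stable log-sum-exp, keeping dims for broadcasting.
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def clip_loss(image_emb, text_emb, temperature=0.07):
    # Normalize so the dot product is cosine similarity.
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    # logits[i, j] = similarity between image i and text j.
    logits = image_emb @ text_emb.T / temperature
    # Matching pairs sit on the diagonal; cross-entropy in both directions
    # pushes their similarity up and all other pairs' similarity down.
    log_p_text_given_image = logits - logsumexp(logits, axis=1)
    log_p_image_given_text = logits - logsumexp(logits, axis=0)
    n = logits.shape[0]
    return -(np.trace(log_p_text_given_image)
             + np.trace(log_p_image_given_text)) / (2 * n)
```

With a batch of matching pairs, the loss is low when each image is most similar to its own caption and high otherwise.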
- This is a fancy (and efficient) method using GANs to encode/decode images into a sequence of tokens (each associated with an index in a learnt codebook).
- We can either use the one from taming-transformers or the one from Suraj, which is already implemented in JAX (and looks to be already trained, judging by the colab in his repo).
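The codebook idea can be sketched in a few lines of numpy; this is a hand-rolled illustration of the quantization step, not the actual taming-transformers or VQGAN-JAX code:

```python
import numpy as np

def quantize(latents, codebook):
    # latents: (n, d) continuous encoder outputs (one vector per grid cell);
    # codebook: (K, d) learnt embedding table.
    # Each latent is mapped to the index of its nearest codebook entry,
    # turning the image into a short sequence of discrete tokens that the
    # decoder can later invert back into pixels.
    dists = ((latents[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)
```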
- DALL-E can generate images from text.
- They first train an image encoder/decoder (called dVAE) that creates a grid of 32x32 tokens.
- Then they create inputs made of text/image pairs: the text is encoded with a BPE tokenizer (up to 256 tokens) and the 32x32 (= 1024) image tokens from the dVAE are appended.
- After that they train a very large model on these sequences so that eventually we can input just text and it generates image tokens.
- By generating a few sample images, they can then use CLIP (see above) to rank them by correlation with the input text and show the best ones.
- Their dataset was “Conceptual Captions” + YFCC100M subset (see post from Arpit)
- Their resource requirements are huge due to the large number of input tokens + the large model.
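The sequence layout described above can be sketched as follows; the text vocabulary size is the one reported in the DALL-E paper, but treat all the constants (including the pad id) as illustrative defaults:

```python
def build_dalle_sequence(text_tokens, image_tokens,
                         text_len=256, text_vocab_size=16384, pad_id=0):
    # Text is BPE-encoded into at most `text_len` tokens and padded; the
    # 32x32 = 1024 dVAE image tokens follow, with their ids shifted past
    # the text vocabulary so the two token spaces don't overlap.
    assert len(text_tokens) <= text_len and len(image_tokens) == 32 * 32
    padded_text = text_tokens + [pad_id] * (text_len - len(text_tokens))
    shifted_image = [t + text_vocab_size for t in image_tokens]
    return padded_text + shifted_image
```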
We can do something similar to DALL-E but simplify it:
- text encoding: we could use the GPT2 tokenizer (no training required)
- image encoding: we can use VQGAN-JAX from Suraj, which uses far fewer tokens than DALL-E (I believe it’s similar to the taming-transformers model). We may not need any fine-tuning.
- concatenate text and image tokens
- train the FLAX-GPT2 model with this script
- to improve inference, we can use a pre-trained CLIP to rank best generated images
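One possible way to build the training sequences for this simplified pipeline: separate the modalities with GPT2’s end-of-text token (id 50256) and shift the VQGAN token ids past the GPT2 vocabulary (50257 entries). The separator choice is an assumption, not a settled decision:

```python
def make_training_sequence(text_ids, image_ids,
                           gpt2_vocab_size=50257, sep_id=50256):
    # GPT2-tokenized text, a separator token (reusing GPT2's end-of-text
    # token here as an assumption), then the VQGAN image tokens shifted
    # past the GPT2 vocabulary so the model can tell the modalities apart.
    return text_ids + [sep_id] + [t + gpt2_vocab_size for t in image_ids]
```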
Please feel free to get started and begin exploring.
I’m not going to be able to work on this project a lot until Tuesday but I’ll follow the forum/discourse for feedback/brainstorming.
- experiment with flax/jax and set up the TPU instance that we should get shortly
- work on dataset loading - see suggested datasets
- Optionally create the OpenAI YFCC100M subset (see this post)
- work on text/image encoding
- concatenate inputs (not sure if we need a fixed length for text or a special token separating text & image)
- adapt training script
- create inference function
- integrate CLIP for better results (only if we have the time)
- work on a demo (streamlit or colab or maybe just HF widget)
- document (set up repo on model hub per instructions, start on README writeup…)
- help with coordinating activities & progress
That’s a really nice write-up @boris
This group has a lot of participation
Giving you guys direct access to TPUs tomorrow! I split the team randomly into two in the official google sheet, but this shouldn’t change anything - it just means you have access to 2 TPU v3-8s.
It might make organization a bit easier to split the work across two VMs!
I’d be very interested in joining this project, if you’re interested in having another person contributing! I’ve posted in the Discord, but I’ll send what I sent there here as well.
I’m a PhD student working on related things for my research, with a good amount of experience training text-to-image transformers (usually on the order of ~300M parameters). I’ve been working in PyTorch so far, but I would be excited to try Jax/Flax and to scale up these models!
Sorry to ping, can you please add me too?
The model was trained on publicly available text-image pairs collected from the internet. This data consists partly of Conceptual Captions and a filtered subset of YFCC100M. We used a subset of the filters described in Sharma et al. to construct this dataset; further details are described in our paper. We will not be releasing the dataset.
Also, they released CLIP trained on the same YFCC100M dataset, and later they added details of the subset used for CLIP.
The subset contains 14,829,396 images, about 15% of the full dataset, and they showed that with this subset CLIP’s performance remained largely the same.
What if the same subset of YFCC100M was used to train DALL·E?
Anyway, as the dataset is publicly accessible, I think you might be interested in it.
Excited to see the end result. Cheers!
Awesome! Thanks for the info @tuner007
Hi, can you let me in also?
See here: Slack
Just pinging you @patrickvonplaten to see if I need to do anything else for the TPU access (I already filled in the sign-up sheet).
Current Status Summary
- on github
- on huggingface - we’ll push from github at the end + add models
- Workflow: I’m adding everyone as a collaborator on the github repo (send me your username). As we need to be fast, I suggest that we do “PR + 1 approval from anybody = merge to main branch”. Small updates (typos, quick bug fixes, readme…) may not even need approval - just notify on the discord.
- input is tokenized text (with a text encoder)
- output is tokenized image (with VQGAN)
- Conceptual 12M data prepared by @greeneggsandyaml
- Conceptual 3M data prepared by @khalidsaifullaah
- YFCC100M: I’m working on creating the OpenAI subset on my local machine (looking good so far, I expect 2TB max). If it works, I’ll try to upload it to datasets for streaming; I created a post to see if it’s feasible.
- Can somebody prepare a mini dataset that can be easily shared with others and used for colab prototyping of the different tasks?
- there is an existing jax model
- needs to be finetuned on our dataset
- ideally we need to finish by Friday at the latest so we have at least a week of training for our full model (which will give us time to finalize our scripts in parallel)
- for people working on other tasks, just use the pre-trained model for now (refer to Suraj’s model). This will be our VQGAN if we don’t manage to fine-tune it in time
- select a base model, non-autoregressive + check that it handles positioning
- can we find a good pre-trained model that does not need fine-tuning? (I imagine we would freeze it)
- Maybe we can adapt jax/hybrid-clip scripts - Suraj mentioned their efficient data loading
- loading data logic
- loss definition + hyperparameters (research similar papers)
- depending on how long it takes to generate images, we could sample a few and re-rank them with the existing OpenAI CLIP
- create inference function
- it would be cool for our demo to work with huggingface widgets (PR in progress)
As usual, feel free to choose where you want to help!
Finally let’s schedule a call with Suraj.
From his calendar, the best for me would be anytime after 8AM Pacific Time. What would work for you?
I have created a small subset with 10 thousand images extracted from CC12M. I’m planning to use it to test local VQGAN training on my GPU, as Suraj said doing it on the TPU could be harder.
Regarding the call, I won’t be able to make it today unfortunately (still traveling).
That’s great! Do you think you could share an even smaller version as a file that can be downloaded in colab, so people can experiment easily?
Call scheduled with Suraj: see google calendar
I created a tiny subset with 512 images (~25MB).
Notes from meeting with Suraj:
- we should try training the VQGAN on a TPU VM; no need for the full YFCC100M subset - let’s just use our existing dataset + the part of YFCC100M that can fit on the VM
- the script for VQGAN uses pytorch lightning; if we can update it, we could take advantage of wandb for automatically pushing checkpoints (a recent feature of the callback)
- for the full model, it may be better not to freeze the text encoder, as the pretrained encoder was trained on a different type of data
- there is some existing Seq2Seq script that we should be able to directly adapt
- we give input ids (raw text) + output ids (image encoded by VQGAN)
- since the output vocabulary is different from the pretrained model’s, the weights will be initialized randomly, so we need to manually reload the encoder part from a pretrained model
- we should build the dataset with preprocessed images (already encoded with VQGAN) so data loading is faster
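The preprocessing idea in the last bullet could look like this; `vqgan_encode` is a hypothetical stand-in for whatever function maps an image to its VQGAN codebook indices:

```python
def preprocess_example(example, vqgan_encode):
    # Pre-encode each image once with the frozen VQGAN so training streams
    # token ids instead of raw pixels.
    example["encoding"] = vqgan_encode(example["image"])
    del example["image"]  # raw pixels are no longer needed after encoding
    return example
```

With HF datasets this would typically be applied once via `dataset.map(...)` and the result saved/cached.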
Things I forgot to ask:
- text preprocessing
- data has title + description + usertags
- should we concatenate them all, or just keep the description or the title? (need to explore)
- I lean toward keeping just the description, as this is what a user would input, or maybe a random mix of all
- See example field here
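A sketch of the “random mix” option, with hypothetical field handling and an arbitrary 0.3 mixing probability (nothing here is a tested choice):

```python
import random

def make_caption(title, description, usertags, rng=random.random):
    # Prefer the description (closest to what a user would actually type),
    # fall back to the title, and occasionally append the usertags.
    caption = description or title
    if usertags and rng() < 0.3:
        caption = f"{caption} {usertags}"
    return caption
```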