DALL-E - mini version

Just pinging you @patrickvonplaten to see if I need to do anything else for the TPU access (I already filled out the sign-up sheet).

Current Status Summary

Repo

  • on github
  • on huggingface - we’ll push from github at the end + add models
  • Workflow: I’m adding everyone as a collaborator on the GitHub repo (send me your username). As we need to be fast, I suggest “PR + 1 approval from anybody = merge to main branch”. Small updates (typos, quick bug fixes, README…) may not even need approval; just notify on the Discord

General Architecture

  • Seq2Seq
  • input is tokenized text (with a text encoder)
  • output is tokenized image (with VQGAN)

Datasets

  • :white_check_mark: Conceptual 12M data prepared by @greeneggsandyaml
  • :white_check_mark: Conceptual 3M data prepared by @khalidsaifullaah :partying_face:
  • :black_square_button: YFCC100M: I’m working on creating the OpenAI subset on my local machine (looking good so far, I expect 2TB max). If it works, I’ll try to upload it to datasets for streaming; I created a post to see if it’s feasible
  • :black_square_button: Can somebody prepare a mini dataset that can be easily shared with others and used for colab prototyping of the different tasks?

VQGAN

  • :information_source: there is an existing JAX model
  • needs to be finetuned on our dataset
    • :black_square_button: @lkhphuc is trying to make a JAX training script (no existing one is available)
    • :black_square_button: alternatively, we can use taming-transformers to train on a custom dataset and convert to JAX: I may be able to try it, but any volunteer would be appreciated (on their local GPU or on our TPU VM)
  • ideally we need to finish by Friday at the latest so we have at least a week of training for our full model (which will give us time to finalize our scripts in parallel)
  • for people working on other tasks, just use a pre-trained model for now (refer to Suraj’s model). This will be our VQGAN if we don’t manage to fine-tune ours in time

Text encoder

  • :black_square_button: select a base model (non-autoregressive) + check that it handles positioning
  • :black_square_button: can we find a good pre-trained model that does not need fine-tuning? (I imagine we would freeze it)

Seq2Seq

  • :information_source: Maybe we can adapt the jax/hybrid-clip scripts - Suraj mentioned their efficient data loading
  • :black_square_button: data loading logic
  • :black_square_button: loss definition + hyperparameters (research similar papers) - see the sketch below
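For the loss, here is a rough sketch of what I have in mind: plain cross-entropy over the VQGAN codebook indices predicted by the decoder (names and shapes below are purely illustrative, not tested):

```python
import jax
import optax

def loss_fn(logits, labels):
    # logits: (batch, seq_len, image_vocab_size) produced by the seq2seq decoder
    # labels: (batch, seq_len) VQGAN codebook indices of the target image
    onehot = jax.nn.one_hot(labels, logits.shape[-1])
    return optax.softmax_cross_entropy(logits, onehot).mean()
```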

Demo

  • :black_square_button: depending on how long it takes to generate images, we could sample a few and re-rank them with the existing OpenAI CLIP (see the sketch after this list)
  • :black_square_button: create inference function
  • :black_square_button: it would be cool for our demo to work with huggingface widgets (PR in progress)
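For the CLIP re-ranking, something along these lines could work (untested sketch using the CLIP port in transformers; the checkpoint name and top_k are just placeholders):

```python
import torch
from transformers import CLIPModel, CLIPProcessor

clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def rerank(prompt, images, top_k=4):
    # images: list of PIL images generated by our model for this prompt
    inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True)
    with torch.no_grad():
        scores = clip(**inputs).logits_per_text[0]  # similarity of the prompt to each image
    best = scores.argsort(descending=True)[:top_k].tolist()
    return [images[i] for i in best]
```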

As usual, feel free to choose where you want to help!

Finally let’s schedule a call with Suraj.
From his calendar, the best for me would be anytime after 8AM Pacific Time. What would work for you?


Excellent summary!

I have created a small subset with 10 thousand images extracted from CC12M. I’m planning to use it to test local VQGAN training on my GPU, as Suraj said doing it on the TPU could be harder.

Regarding the call, I won’t be able to make it today unfortunately (still traveling).


That’s great! Do you think you could share an even smaller version as a file that can be downloaded in Colab, so people can experiment easily?

Call scheduled with Suraj: see google calendar

I created a tiny subset with 512 images (~25MB).


Notes from meeting with Suraj:

  • we should try training the VQGAN on a TPU VM; no need for the full YFCC100M subset - let’s just use our existing dataset + the part of YFCC100M that fits on the VM
  • the VQGAN script uses PyTorch Lightning; if we can update it, we could take advantage of wandb for automatically pushing checkpoints (a recent feature of the callback)
  • for the full model, it may be better not to freeze the text encoder, as the pretrained encoder was trained on a different type of data
  • there is some existing Seq2Seq script that we should be able to directly adapt
    • we give input ids (raw text) + output ids (image encoded by VQGAN)
    • since the output is different from the pretrained model, it will be initialized with random weights, so we need to manually reload the encoder part from a pretrained model
    • we should build the dataset with preprocessed images (encoded with VQGAN) so the data loading is faster

Things I forgot to ask:

  • text preprocessing
    • data has title + description + usertags
    • should we concatenate it all, or just keep the description or the title? (need to explore)
    • I lean toward either keeping just the description, as this is what a user would typically input, or maybe a random mix of all (see the toy sketch below)
    • See example field here
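To make the last point concrete, here is a toy sketch of what the “random mix” could look like (field names taken from the metadata above; the 50% probability is arbitrary):

```python
import random

def make_caption(item, rng=random):
    # half the time, keep just the description (closest to what a user would type)
    if item.get("description") and rng.random() < 0.5:
        return item["description"]
    # otherwise concatenate whatever is available among title / description / usertags
    parts = [item.get("title"), item.get("description"), item.get("usertags")]
    return " ".join(p for p in parts if p)
```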

There’s another effort to recreate DALL-E: DALL-E

There’s also an effort there to get text-image pairs from Common Crawl, currently at 100 million. :slight_smile:


With this, our currently released models are located here: GitHub - robvanvolt/DALLE-models: Here is a collection of checkpoints for DALLE-pytorch models, from where you can keep on training or start generating images.
Along with the inference colab.


Status

  • :white_check_mark: datasets for VQGAN ready (thanks Khalid), training being set up (Boris & Pedro)
  • :stop_sign: debugging of the VQGAN on TPU aborted (Pedro may make a last attempt with help from Tanishq)
  • :black_square_button: test the script to convert a VQGAN (try a pretrained one) to JAX from Suraj’s repo
  • :black_square_button: prepare a function that turns an image into encoded tokens from VQGAN (test with pretrained model)
  • :black_square_button: create a dataset that contains only the target text + image tokens (if it’s small we can do it on all the images we have access to) - I suggest trying datasets with map (I would not batch it, as the intermediate dataset with loaded images may be too big and become the bottleneck) - see the sketch after this list
  • :black_square_button: prepare the seq2seq jax script (test an example first and adapt with one of our dummy datasets)
  • :black_square_button: prepare demo
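For the “text + image tokens” dataset, this is roughly the datasets pattern I had in mind (the encode_image helper is hypothetical - it would wrap the pretrained VQGAN encoder - and the columns are only for illustration):

```python
from datasets import Dataset

def encode_image(path):
    # hypothetical helper: open the image, run the (pretrained) VQGAN encoder,
    # and return the flat list of codebook indices
    return [0] * 256  # placeholder output

ds = Dataset.from_dict({
    "caption": ["a sunset over the mountains"],  # illustrative rows
    "image_path": ["images/0001.jpg"],
})

# unbatched map: only one decoded image in flight at a time, and the intermediate
# dataset only stores the (small) token sequences
ds = ds.map(
    lambda ex: {"encoding": encode_image(ex["image_path"])},
    remove_columns=["image_path"],
)
```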

A lot of these tasks can be done in parallel!
We can coordinate on the discord.


Status

  • :white_check_mark: VQGAN trained - Pytorch model
  • :white_check_mark: conversion script + VQGAN converted to JAX - JAX model
  • :white_check_mark: image encoding script - thanks @pcuenq
  • :white_check_mark: images from CC3M and CC12M pre-encoded with VQGAN - thanks @pcuenq and @khalidsaifullaah - Can any of you check that we dropna for non-existing images on both datasets
  • :black_square_button: images from YFCC100M pre-encoded with VQGAN - @khalidsaifullaah if you are interested
    • file is metadata.jsonl and can be read with pandas.read_json("metadata.jsonl", orient="records", lines=True)
    • explore and see if we prefer text_clean or description_clean or any other logic
    • some files are not present on our VM, so we will need to remove those items (see the sketch after this list)
  • :white_check_mark: dataset pipeline - thanks @pcuenq - Can we concatenate datasets by passing a list of files?
  • :white_check_mark: model inference - thanks @lkhphuc for the help
  • :black_square_button: prepare the seq2seq jax script - I already started it but didn’t have the time to test: see here
    Note: I think for now we should have predict_with_generate=False as we would need to define metrics (not obvious to find anything better than the loss here). When it runs we can just log some sample predictions decoded with the VQGAN.
  • :black_square_button: fine-tune the learning rate and warmup steps
  • :black_square_button: final training
  • :black_square_button: test the generate function - it should already handle bos/eos/pos tokens properly in the decoder based on the config
  • :black_square_button: make a JAX VQVAE - use the Haiku implementation + Suraj’s VQGAN repo as an example
  • :black_square_button: prepare demo - we may want to generate several images and use CLIP for reranking
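For the YFCC100M item above, a rough sketch of the exploration/cleanup I’m thinking of (the image path column name is an assumption on my side):

```python
import os
import pandas as pd

df = pd.read_json("metadata.jsonl", orient="records", lines=True)

# drop items whose image file is missing on the VM ("image_path" column name assumed)
df = df[df["image_path"].map(os.path.exists)]

# compare the candidate caption fields before picking one (or a mix of them)
print(df[["text_clean", "description_clean"]].head(20))
```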

Thanks everybody!!!


Status

  • :white_check_mark: Training ongoing for the Seq2Seq (2 active runs doing pretty well) - see dashboard
  • :white_check_mark: Awesome demo development and tests ongoing by everybody
  • :white_check_mark: CLIP integrated by @pcuenq
  • :black_square_button: Final demo to develop
    • UI
      • Option A: huggingface widget - seems like the required PR has not been merged yet
      • Option B: Streamlit
      • Option C: Gradio
      • Option D: Colab
      • Notes:
        • We can potentially get a T4 GPU and host it on huggingface spaces - @valhalla would that be possible?
        • We also need to see if those options support JAX well (I would think so, except maybe for the huggingface widget)
      • I think it’s important to pursue different options and see what works best in the end
    • Inference
      • Reference: see the demo folder and the colab from @lkhphuc
      • Check how we can make it faster in JAX (maybe some compilation tricks)
      • Test Pytorch vs JAX inference (regular Bart summarization model); if it’s worth it, we’ll convert our model to Pytorch (we’d need to recreate the architecture + transfer the weights correctly) - see the timing sketch after this list
      • Push the model to the hub as it could potentially be loaded faster if the app is hosted by HF
  • :black_square_button: Clean up repo - not sure if we want to bring our fork of taming-transformers as well
  • :black_square_button: Writeup - I would suggest making the main writeup a W&B report, since it will link all our runs together, and keeping a clean repo for reproducibility
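For the Pytorch vs JAX comparison, a quick-and-dirty timing sketch with a regular Bart summarization checkpoint (untested; remember the first JAX call includes XLA compilation, so warm up before timing):

```python
import time
from transformers import BartTokenizer, BartForConditionalGeneration, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
pt_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
flax_model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

text = "The tower is 324 metres tall, about the same height as an 81-storey building."
pt_inputs = tokenizer(text, return_tensors="pt")
flax_inputs = tokenizer(text, return_tensors="jax")

flax_model.generate(**flax_inputs)  # warm-up call so compilation time is not counted

start = time.time(); pt_model.generate(**pt_inputs); print("pytorch:", time.time() - start)
start = time.time(); flax_model.generate(**flax_inputs); print("jax:", time.time() - start)
```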

:partying_face: Congrats everybody for the great results we’re getting!


Status update

  • Demo

    • We’ll design a streamlit app in an “app” folder
      • → lets us keep the app and code together as they evolve
      • → easier for several members to contribute (branches, PRs…)
      • → lets us experiment locally
    • We’ll force-push our repo to huggingface spaces - only the README seems to need a few important tags at the top, and the rest can be anything (we don’t have any content for now; we’ll put a link to the app + a link to the report + setup instructions)
    • Once it’s set up, we can probably set up a Github action so our repo is pushed automatically to huggingface spaces
    • we should push our best checkpoint (even if we update it later) to the hub so we can reference it (I imagine it will load faster, as it’s probably on the same servers) - we should reference the model commit id in case someone pushes a new version to our repo by mistake
    • I believe @ghosh-r has started prototyping something so he can start pushing his code
  • Report

    • We’ll do it through W&B. See current report here
    • I’m checking how to add the contributors as authors but you should have access to edit it
    • @pcuenq is taking care of the predictions section (right now it’s automatically updated with our latest run :heart_eyes:)
    • @khalidsaifullaah is making a super cool graph
  • Documentation

    • update repo README
    • model cards VQGAN + mini-DALLE
  • Ongoing runs

    • Current best run
    • Alternative run - starting from a checkpoint and with a higher learning rate - expect the loss to go up for a few hours and then it should go back down

Status update

  • Runs

    • Our current best run has a decreasing eval loss
    • However, I checked manually and it seems it now predicts only the same token (10042). Can somebody double-check based on wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest?
    • If that’s the case, we may have to manually pick the best-performing model and find out why this happened - @pcuenq could you use your script to loop over all versions of wandb/hf-flax-dalle-mini/model-4oh3u7ca? (see the sketch at the end of this post)
  • Demo

    • @tmabraham has created a cool Gradio demo. The UI is awesome but the inference is slow (however not slower than our Streamlit version). Once we can test it on HF spaces we may have to adapt how many images are generated.
    • We are having trouble installing dependencies on our HF space. @pcuenq has reached out to the HF team through slack so they can help us resolve the issue (seems to be related to the cuda toolkit).
    • We can set up some suggested prompts, so let’s play with it and find 5-8 cool prompts to suggest (they could be great for our report too)!
  • Repos

    • I created some basic cards for our datasets and models
    • we need to refactor our repo a bit. Ideally the root will only have the README, requirements.txt for the app (it needs to be there), an app folder, a dev folder (with our notebooks), and the rest in a single folder (whether “src” or “dalle-mini”).
    • our app should pull the model definition from the repo with something like from dalle_mini.model import CustomFlaxBartForConditionalGeneration
    • it would be cool to force push our repo to HF space and set up a Github action to always push our master branch there
  • Writeup

    • @khalidsaifullaah has created some awesome diagrams of our architecture
    • the report will be in W&B and will link to our runs. Pedro has already prepared some cool stuff for our predictions section.
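For looping over the model versions mentioned under “Runs”, something like this should work with the wandb public API (untested; adjust the artifact path if needed):

```python
import wandb

api = wandb.Api()

# iterate over every logged version of the model artifact and download it for inspection
for artifact in api.artifact_versions("model", "wandb/hf-flax-dalle-mini/model-4oh3u7ca"):
    local_dir = artifact.download()
    print(artifact.version, "->", local_dir)
    # load the checkpoint from local_dir and generate a few samples to spot where it regressed
```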

Status update

  • Our report is ready and looks awesome! Possible improvements:

    • top pic in intro could be a bit better - feel free to generate something interesting
    • generated samples could potentially be better - @pcuenq if you have a chance you could try the latest model with 128 samples
  • Demo

    • Works locally and looks pretty great!
    • Pushed to HF space - we will need their help to set it up
    • The screenshot has white space at the top - maybe @tmabraham can check with Gradio?
    • Maybe we can remove the example prompts (unless they become a drop-down in the input), or we can just have a cool default prompt
  • Repo

    • I did some big cleanup of the repo - see PR

Status update

  • Demo almost ready (we just need to upgrade streamlit on HF when possible)

  • Report almost ready (let’s do a final read and check links, etc.)

  • Repo

    • Use the model definition in our training script and notebooks (see app)
    • Clean up our notebooks - remove the useless ones and have some simple ones - inference (regular + TPU), use of VQGAN, how to encode an image

In addition to my previous comments, see the evaluation form that has great ideas from Patrick on what we can improve!


Status update

Almost done :partying_face:

  • Demo

    • :white_check_mark: Github repo now in sync with Spaces
    • :black_square_button: update streamlit on HF Spaces (and remove our hacks)
  • Report

    • :white_check_mark: Evolution of predictions complete
    • :white_check_mark: Updated section on limitations and biases - Please review and see if you have more to add
    • :black_square_button: compare to DALLE-pytorch - I could not find any generic model; there are some trained on tiny datasets (for example birds), which would not be great for comparison. Can you find any?
  • Model cards

    • :black_square_button: VQGAN JAX card being updated by @pcuenq
    • :black_square_button: VQGAN Pytorch version to complete (I can copy from JAX version + add details on requirements and inference)
    • :black_square_button: complete DALL-E mini card (inference script could be nice and does not need to include CLIP)
  • Repo

    • :black_square_button: clean up notebooks - I think we just need one for encoding data + one for inference

Status

  • :white_check_mark: Report to be released soon
  • :white_check_mark: All model cards updated (thanks Pedro)
  • :white_check_mark: Inference colab ready
  • :black_square_button: Demo - waiting for the upgrade of streamlit to fix our layout
  • :orange_square: New run ongoing: same data, more epochs
  • :black_square_button: working on dataset loading script for YFCC100M - see here
  • :black_square_button: trying to fix optimizer checkpoint - see here

Let’s plan next steps soon.


For those following who don’t know, the demo has already been released.

Feel free to join the Dall-E Discord server to help with this project!