Hey @boris
That’s a great idea, I’m super excited about this!
Regarding the dataset, here are a few that could be used for this:
- Conceptual Captions - ~3.3M image-text pairs
- Conceptual 12M - ~12M image-text pairs
- WIT - ~37M image-text pairs
Also, regarding the architecture: it seems a simple scaled-up GPT-2 as the LM can also give good results (see CogView). By simple, I mean no sparse attention or row/column attention, etc., which I guess DALL-E uses. GPT-2 is already available in JAX, and JAX is way faster than PyTorch on TPU.
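Just to make the idea concrete, here's a rough (untested) sketch of what I mean, using the Flax GPT-2 from `transformers` with the vocab extended to cover an image codebook. All the sizes here (codebook of 16384, 12 layers, etc.) are placeholders, not a proposal:

```python
# Rough sketch (assumptions: Flax GPT-2 from transformers, text vocab + VQGAN
# codebook sharing one embedding table; all sizes below are illustrative).
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2Config

config = GPT2Config(
    vocab_size=50257 + 16384,   # text BPE tokens + image codebook tokens
    n_positions=512,            # 256 text + 256 image tokens
    n_layer=12,
    n_head=12,
    n_embd=768,
)
model = FlaxGPT2LMHeadModel(config, seed=0)

# Dummy forward pass over a batch of concatenated [text ids | image ids].
dummy_ids = jnp.zeros((2, 512), dtype=jnp.int32)
logits = model(dummy_ids).logits
print(logits.shape)  # (2, 512, vocab_size)
```

So training would just be standard next-token prediction over the concatenated sequence, no custom attention needed.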
And if you look at this discussion, it seems an even smaller model could give good-enough results on a domain-specific dataset.
Also, as suggested in the discussion above, using the VQGAN from taming-transformers as the image tokenizer could further reduce the complexity of training such models: the max image token length for these VQGAN models is 256 (way less than DALL-E's dVAE, which uses 1024). So overall, 256 text tokens + 256 image tokens = 512, which should be manageable on a single v3-8.
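The sequence-length math, just as a quick sketch (the f=16 VQGAN and the 256-token text cap are assumptions on my side):

```python
# Sequence-length budget sketch (assumptions: 256x256 images, an f=16 VQGAN
# from taming-transformers, text capped at 256 BPE tokens).
import jax.numpy as jnp

image_size = 256
downsample = 16                                  # f=16 VQGAN
image_tokens = (image_size // downsample) ** 2   # 16 * 16 = 256 codebook indices
text_tokens = 256                                # assumed cap on text tokens
seq_len = text_tokens + image_tokens             # 512 total, vs 256 + 1024 for DALL-E

# One training example = [text ids | image codebook ids], modeled left-to-right.
text_ids = jnp.zeros((text_tokens,), dtype=jnp.int32)    # placeholder ids
image_ids = jnp.zeros((image_tokens,), dtype=jnp.int32)  # placeholder ids
example = jnp.concatenate([text_ids, image_ids])
assert example.shape == (seq_len,)
```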