Hello all,
I have tried my best to organize Flax/JAX learning resources. I hope they will be useful for all of us.
If you have more learning resources, drop a comment & I will add them here…
JAX is a numerical computation library that exposes a NumPy-like API with tracing capabilities.
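To make "NumPy-like API with tracing" concrete, here is a minimal sketch. It assumes a recent `jax` install, and the toy linear-regression loss is purely illustrative. The function is written with `jax.numpy` exactly as it would be with NumPy. JAX then traces it: `jax.make_jaxpr` shows the traced program, `jax.grad` derives its gradient, and `jax.jit` compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

# Ordinary NumPy-style code: mean-squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# Tracing: inspect the intermediate program (a "jaxpr") JAX records for `loss`.
print(jax.make_jaxpr(loss)(w, x, y))

# Transformations built on tracing: differentiate, then compile with XLA.
fast_grad = jax.jit(jax.grad(loss))

print(loss(w, x, y))      # 2.5
g = fast_grad(w, x, y)    # gradient values: -7.0 and -10.0
print(g)
```

The transformations compose, so `jax.jit(jax.grad(...))` is just function composition over traced code.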
Flax is the most widely used JAX library, with 129 dependent projects as of May 2021.
Flax builds on top of JAX with an ergonomic module abstraction using Python dataclasses that leads to concise and explicit code.
Flax’s “lifted” JAX transformations (e.g. vmap, remat) allow you to nest JAX transformations and modules in any way you wish. Flax is also the library underlying all of the official Cloud TPU JAX examples.
GitHub links:
JAX Projects
Flax examples
Colab links:
Causal Language Modelling on OSCAR
Masked Language Modelling on OSCAR
Text Classification on GLUE
Docs:
JAX 101 Tutorial
Flax Basics Colab
Flax examples
More useful links:
Guide to running JAX on Google Cloud TPU
Load Dataset in streaming mode
Talks (date & time):
30.6.2021: 9am - 11am PST Zoom link
1.7.2021: 8:30am - 10am PST Zoom link
2.7.2021: 8am - 10am PST Zoom link