BUG Confirmation: BigBirdLM not able to use Flax

On Google Colab, trying to use FlaxBigBirdForMaskedLM raises an error saying that Flax is not installed.

from transformers import BigBirdConfig

config = BigBirdConfig(
    vocab_size=40000,
    hidden_size=768,
    max_position_embeddings=16000,
    num_attention_heads=4,     # 6
    num_hidden_layers=4,       # 6
)

from transformers import FlaxBigBirdForMaskedLM
model = FlaxBigBirdForMaskedLM(config=config)

It may be due to some other issue, but I am consistently getting this error across all runtimes (the target runtime is TPU).

ImportError: 
FlaxBigBirdForMaskedLM requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.

Flax is indeed installed, via `!pip install -q transformers[flax]`.

Does this seem like a genuine bug? The problem is that the Flax backend is apparently not visible to the model class, even though I can import flax and other utilities without any issue.
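
A quick sanity check in the same runtime, to see whether transformers itself can find the backend (a sketch; it assumes is_flax_available is still importable from transformers.file_utils in this version):

# can transformers itself see the Flax backend in this runtime?
import flax, jax
print("flax:", flax.__version__, "| jax:", jax.__version__)

from transformers.file_utils import is_flax_available
print("is_flax_available():", is_flax_available())  # False would explain the ImportError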

cc @vasudevgupta, who I believe is the one working on BigBird and Flax

Also looking into the issue! Do you experience this with other Flax models like FlaxGPT2Model or FlaxBartModel as well?
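
For example, something along these lines would be enough to test the import path (a minimal sketch with tiny random-weight configs, only to check that the classes instantiate):

# tiny random-weight models, just to test whether the Flax import path works at all
from transformers import FlaxGPT2Model, GPT2Config
from transformers import FlaxBartModel, BartConfig

gpt2 = FlaxGPT2Model(GPT2Config(n_layer=2, n_head=2, n_embd=64))
bart = FlaxBartModel(BartConfig(encoder_layers=2, decoder_layers=2))
print(type(gpt2), type(bart))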

Also see: [Flax] Add jax flax to env command by patrickvonplaten · Pull Request #12251 · huggingface/transformers · GitHub


Hmm… what environment did you use for executing the tests? I can try replicating it in Colab and see if that solves the issue :thinking:

It still doesn’t work for me when I try installing from the master branch. Can you reproduce it on your end?

I can’t reproduce sadly :-/

I’m using the following env (one can now just run `transformers-cli env`):

- `transformers` version: 4.8.0.dev0
- Platform: Linux-5.4.0-74-generic-x86_64-with-glibc2.27
- Python version: 3.9.1
- PyTorch version (GPU?): 1.8.1+cpu (False)
- Tensorflow version (GPU?): 2.5.0-rc2 (False)
- Flax version (CPU?/GPU?/TPU?): 0.3.4 (cpu)
- Jax version: 0.2.13
- JaxLib version: 0.1.65
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No

Ahh, never mind - it was just jax not being installed because the install cell was commented out.

Anyway, I had a quick question: the object returned by FlaxBigBirdForMaskedLM doesn’t seem to be a model; it doesn’t have the .num_parameters() method.

<transformers.models.big_bird.modeling_flax_big_bird.FlaxBigBirdForMaskedLM at 0x7fcbeda59850>
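
A workaround sketch for counting parameters over the parameter pytree (assuming the weights live under model.params, as with the other Flax models):

# work around the missing .num_parameters(): sum the sizes of all leaves in the param pytree
import jax
n_params = sum(p.size for p in jax.tree_util.tree_leaves(model.params))
print(f"{n_params:,} parameters")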

Assuming it’s a difference between the frameworks and moving on, further down the line I encounter:

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-22-cc7b30181d8f> in <module>()
     29     data_collator=data_collator,
     30     train_dataset=mapped_dataset['train'],
---> 31     eval_dataset=mapped_dataset['test']
     32 )

/usr/local/lib/python3.7/dist-packages/transformers/trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers)
    361 
    362         if self.place_model_on_device:
--> 363             model = model.to(args.device)
    364 
    365         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs

AttributeError: 'FlaxBigBirdForMaskedLM' object has no attribute 'to'

Clearly some methods are missing. That call would be valid if we were using the PyTorch version of the model, but apparently not here.
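
For comparison, a minimal sketch with the PyTorch class, which does expose .to() and is what the Trainer’s model.to(args.device) call expects (reusing the config from the first cell above):

# the PyTorch class built from the same config does have .to(), which is what
# the Trainer calls internally via model.to(args.device)
from transformers import BigBirdForMaskedLM
pt_model = BigBirdForMaskedLM(config)
pt_model.to("cpu")  # fine here; the Flax class instead keeps its weights in model.params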

What might be wrong is that the TPU device isn’t being initialized properly (I get a warning), and maybe that plays havoc if it doesn’t fall back to CPU correctly?
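
A quick check to see which backend JAX actually picked up (a sketch; on Colab the TPU may additionally need jax.tools.colab_tpu.setup_tpu() before it shows up):

# which backend did jax actually pick up? On a mis-configured TPU runtime it
# silently falls back to CPU.
import jax
print(jax.devices())              # e.g. [TpuDevice(...)] or [CpuDevice(id=0)]
print(jax.devices()[0].platform)  # 'tpu', 'gpu' or 'cpu'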

Should I make a repro Colab?

Should I continue this thread, or should I open a new issue on GitHub? :thinking:
I presume this problem might also exist in BERT’s MLM script.

You are using a script built on the Trainer API, which is a PyTorch API, with a Flax model. This cannot work. The Flax scripts are here.
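
Roughly, those scripts build the training loop around a Flax train state instead of the Trainer; a minimal sketch of that setup (optax is assumed to be installed, and the learning rate is just a placeholder):

# rough shape of the Flax-side setup (not the Trainer API)
import optax
from flax.training import train_state
from transformers import BigBirdConfig, FlaxBigBirdForMaskedLM

config = BigBirdConfig(vocab_size=40000, hidden_size=768,
                       max_position_embeddings=16000,
                       num_attention_heads=4, num_hidden_layers=4)
model = FlaxBigBirdForMaskedLM(config=config)

# Flax models are stateless: the weights live in model.params and are threaded
# through an optimizer state rather than moved with model.to(device)
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optax.adamw(learning_rate=1e-4),
)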

Probably why I don’t like Hugging Face very much - it’s just too fragmented for useful customization (though to its credit, it does warn about that con in the README). I’ll probably be getting a host of new errors now with the new scripts :frowning:

  1. Why do we need to use the Flax version for running BigBird on TPU?
  2. Why did Google opt to release BigBird on HF rather than as a standalone PyTorch/JAX repo?

Not an expert, but from what I know, Google has shared a standalone BigBird:
google-research/bigbird: Transformers for Longer Sequences (github.com)

I know, but it’s in TF, which is not much better than Hugging Face. That’s why I asked about a JAX lib, since that’s what Google developed.

The PyTorch model with the Trainer API runs perfectly fine on TPU. See here for how to launch the PyTorch example scripts on TPUs.
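
Under the hood that path goes through torch_xla; a minimal sketch of the device placement the Trainer does for you on TPU (assuming torch_xla is installed in the runtime):

# sketch of the device placement Trainer performs internally on TPU
import torch_xla.core.xla_model as xm
from transformers import BigBirdConfig, BigBirdForMaskedLM

config = BigBirdConfig(vocab_size=40000, hidden_size=768,
                       max_position_embeddings=16000,
                       num_attention_heads=4, num_hidden_layers=4)

device = xm.xla_device()                           # one TPU core
pt_model = BigBirdForMaskedLM(config).to(device)   # the PyTorch class has .to()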