On Google Colab, trying to use FlaxBigBirdForMaskedLM shows that Flax is not installed.
from transformers import BigBirdConfig
config = BigBirdConfig(
    vocab_size=40000,
    hidden_size=768,
    max_position_embeddings=16000,
    num_attention_heads=4,  # 6
    num_hidden_layers=4,  # 6
)
from transformers import FlaxBigBirdForMaskedLM
model = FlaxBigBirdForMaskedLM(config=config)
It may be due to some other issue, but I am consistently getting this error across all runtimes (the target runtime is TPU).
ImportError:
FlaxBigBirdForMaskedLM requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
Flax is indeed installed, via !pip install -q transformers[flax].
Does this seem like a genuine bug? The problem is that the Flax backend is apparently not accessible to the method, even though I can import flax and other utilities without issue.
Ahh, nvm - it was just jax not being installed due to a commented cell.
Anyways, I had a quick question - the object returned by FlaxBigBirdForMaskedLM doesn’t seem to be a model; it doesn’t have the .num_parameters() method.
<transformers.models.big_bird.modeling_flax_big_bird.FlaxBigBirdForMaskedLM at 0x7fcbeda59850>
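As I understand it, Flax models keep their weights in a nested dict (a pytree) under `model.params` instead of exposing `.num_parameters()`. A minimal sketch of counting parameters by walking that structure; the toy pytree and its layer names below are purely illustrative, not the real BigBird layout:

```python
import numpy as np

def count_params(params):
    # Recursively sum array sizes in a nested dict of parameters,
    # the shape a Flax model exposes under `model.params`.
    if isinstance(params, dict):
        return sum(count_params(v) for v in params.values())
    return int(np.prod(np.asarray(params).shape))

# Toy pytree standing in for model.params (names are made up)
toy_params = {
    "embeddings": {"word_embeddings": np.zeros((40000, 768))},
    "encoder": {"layer_0": {"kernel": np.zeros((768, 768)), "bias": np.zeros(768)}},
}
print(count_params(toy_params))  # 40000*768 + 768*768 + 768 = 31310592
```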
Assuming it's a difference in frameworks and moving on, further down the line I encounter:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-22-cc7b30181d8f> in <module>()
29 data_collator=data_collator,
30 train_dataset=mapped_dataset['train'],
---> 31 eval_dataset=mapped_dataset['test']
32 )
/usr/local/lib/python3.7/dist-packages/transformers/trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers)
361
362 if self.place_model_on_device:
--> 363 model = model.to(args.device)
364
365 # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
AttributeError: 'FlaxBigBirdForMaskedLM' object has no attribute 'to'
Clearly some methods are missing. model.to(args.device) would be a valid call if we were using the PyTorch version, but not for the Flax one.
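My understanding of why `.to()` is missing: Flax models are functional, so parameters live in a pytree that gets passed into pure functions explicitly, rather than inside a mutable module that can be moved between devices. A toy numpy sketch of that update pattern; nothing here is the real transformers or Flax API, plain numpy just stands in for jax:

```python
import numpy as np

def loss_fn(params, x, y):
    # Toy linear model; parameters are passed in explicitly,
    # mirroring the functional style of Flax/JAX models.
    pred = x @ params["w"] + params["b"]
    return float(np.mean((pred - y) ** 2))

def sgd_step(params, grads, lr=0.1):
    # An update returns *new* params instead of mutating a module
    # in place, which is why there is no .to(device) to call.
    return {k: params[k] - lr * grads[k] for k in params}

params = {"w": np.zeros(2), "b": np.array(0.0)}
x, y = np.array([[1.0, 2.0]]), np.array([3.0])
print(loss_fn(params, x, y))  # MSE at the zero init -> 9.0
grads = {"w": np.array([-6.0, -12.0]), "b": np.array(-6.0)}  # d(MSE)/dparams at the zero init
params = sgd_step(params, grads)
print(params["w"], params["b"])  # -> [0.6 1.2] 0.6
```

This is also why the PyTorch `Trainer` cannot drive a Flax model: it assumes a stateful `torch.nn.Module` interface.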
What might be wrong is that the TPU device isn't being initialized properly (I get a warning), and maybe that plays havoc if it doesn't fall back on CPU correctly?
Probably why I don't like Huggingface very much - it's just too fragmented for useful customization (though to its credit, it does warn about that con in the readme). I expect I'd be hitting a host of new errors now with new scripts.
Why do we need to use the Flax version for running BigBird on TPU?
Why did Google opt to release BigBird on HF, rather than as a standalone PyTorch/JAX repo?