FlaxVisionEncoderDecoderModel decoder_start_token_id

I'm following the documentation here to instantiate a FlaxVisionEncoderDecoderModel, but I am unable to do so.

I'm on Transformers 4.15.0.

from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# load output tokenizer
tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")

# initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained("vit", "gpt2")


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-d9e3c9f46932> in <module>
      5 
      6 # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
----> 7 model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained("vit", "gpt2")
      8 
      9 pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

AttributeError: type object 'FlaxVisionEncoderDecoderModel' has no attribute 'from_encoder_decoder_pretrained'

Hi. Do you have Flax/JAX installed on your computer? Both are required in order to use FlaxVisionEncoderDecoderModel.
(There should be a better error message for this situation, and that will be fixed.)
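
One quick way to check this is to ask Python whether the two packages can be found in the current environment, without importing them. The snippet below is only a minimal sketch: the helper name `missing_packages` is made up for illustration and is not part of Transformers, and it only tests whether the packages can be located, not whether their versions are compatible.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    # find_spec returns None when a top-level module cannot be located.
    return [name for name in names if importlib.util.find_spec(name) is None]


# "jax" and "flax" are the import names of the two packages mentioned above.
missing = missing_packages(["jax", "flax"])

if missing:
    print("Not installed:", ", ".join(missing))
else:
    print("jax and flax were both found.")
```

If either package is reported as not installed, install it in the same environment that the notebook kernel uses, restart the kernel, and run the original code again.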