Equivalent of `inputs_embeds` for `FlaxGPT2Model`

Hi,
The forward method of `GPT2Model` lets the user pass embedded token representations directly to GPT-2 through the `inputs_embeds` keyword argument. However, the `__call__` method of `FlaxGPT2Model` does not seem to accept this keyword. Is there a workaround to get the same functionality? Thanks.
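For context, `inputs_embeds` in the PyTorch `GPT2Model` only replaces the token-embedding lookup. Position embeddings are still added, and the transformer blocks run as usual. Below is a minimal pure-JAX sketch of that input step. It uses a hypothetical helper `embed_inputs` and small random tables that stand in for GPT-2's `wte` (token) and `wpe` (position) embeddings. It is not the Flax model's own code. The sketch shows that passing `wte[input_ids]` as `inputs_embeds` yields the same hidden states as passing the ids. One possible workaround, which is only an assumption to check against your `transformers` version, is to compute these hidden states yourself and then run the rest of the Flax module's layers on them.

```python
import jax
import jax.numpy as jnp


def embed_inputs(wte, wpe, input_ids=None, inputs_embeds=None):
    """Hypothetical helper mirroring GPT-2's input embedding step.

    wte: (vocab_size, hidden) token embedding table
    wpe: (max_positions, hidden) position embedding table
    Pass exactly one of input_ids (batch, seq) or inputs_embeds (batch, seq, hidden).
    """
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Pass exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        # Token lookup: row i of wte is the embedding of token id i.
        inputs_embeds = jnp.take(wte, input_ids, axis=0)
    seq_len = inputs_embeds.shape[1]
    # Position ids 0..seq_len-1, broadcast over the batch dimension.
    position_embeds = wpe[jnp.arange(seq_len)]
    return inputs_embeds + position_embeds[None, :, :]


# Toy tables standing in for real GPT-2 weights (assumption: tiny sizes for illustration).
key_wte, key_wpe = jax.random.split(jax.random.PRNGKey(0))
wte = jax.random.normal(key_wte, (50, 8))
wpe = jax.random.normal(key_wpe, (16, 8))
ids = jnp.array([[3, 7, 1, 4]])

from_ids = embed_inputs(wte, wpe, input_ids=ids)
# Precomputed (or modified) embeddings go in place of the id lookup.
from_embeds = embed_inputs(wte, wpe, inputs_embeds=wte[ids])
print(jnp.allclose(from_ids, from_embeds))  # True
```

Because the two entry points agree, any custom embedding of shape `(batch, seq, hidden)` can be substituted at this step. Everything after it in the forward pass stays unchanged.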
