__call__() got an unexpected keyword argument 'special_tokens_mask' when running run_t5_mlm_flax.py

Hi, I am getting the following error while pre-training a t5-base model on my custom dataset, starting from the t5-base checkpoint. I am using the run_t5_mlm_flax.py script from the transformers/examples/flax directory.
I ran the script with the following command:

python run_t5_mlm_flax.py --output_dir <output path> --model_name_or_path t5-base --train_file <train file path> --validation_file <validation file path> --do_train --do_eval --num_train_epochs 1 --max_seq_length 512 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --weight_decay 0.001 --warmup_steps 2000 --overwrite_output_dir --logging_steps 500 --seed 1 --save_steps 500 --eval_steps 500

Here, <train file path> is the path to the training .txt file, which contains one sentence per line.
<validation file path> points to a validation .txt file in the same format.

Running it produced this error:

STACK TRACE

Traceback (most recent call last):
  File "run_t5_mlm_flax.py", line 964, in <module>
    main()
  File "run_t5_mlm_flax.py", line 877, in main
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2156, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2038, in pmap_f
    global_arg_shapes=p.global_arg_shapes_flat)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2040, in bind
    return map_bind(self, fun, *args, **params)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2072, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2043, in process
    return trace.process_map(self, fun, tracers, params)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 910, in xla_pmap_impl
    *abstract_args)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 937, in parallel_callable
    in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 1109, in lower_parallel_callable
    pci, fun, global_arg_shapes)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 1016, in stage_parallel_callable
    fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2177, in trace_to_jaxpr_final
    fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2109, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "run_t5_mlm_flax.py", line 813, in train_step
    loss, grad = grad_fn(state.params)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 1071, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2517, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/ad.py", line 133, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/ad.py", line 122, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 802, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "run_t5_mlm_flax.py", line 805, in loss_fn
    logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
jax._src.traceback_util.UnfilteredStackTrace: TypeError: __call__() got an unexpected keyword argument 'special_tokens_mask'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "run_t5_mlm_flax.py", line 964, in <module>
    main()
  File "run_t5_mlm_flax.py", line 877, in main
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "run_t5_mlm_flax.py", line 813, in train_step
    loss, grad = grad_fn(state.params)
  File "run_t5_mlm_flax.py", line 805, in loss_fn
    logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
TypeError: __call__() got an unexpected keyword argument 'special_tokens_mask'