Hi, I am getting the following error while pre-training the t5-base model on my custom dataset, starting from the t5-base checkpoint. I am using the run_t5_mlm_flax.py script from the transformers/examples/flax directory.
I ran the script with the following command:
python run_t5_mlm_flax.py --output_dir <output path> --model_name_or_path t5-base --train_file <train file path> --validation_file <validation file path> --do_train --do_eval --num_train_epochs 1 --max_seq_length 512 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --weight_decay 0.001 --warmup_steps 2000 --overwrite_output_dir --logging_steps 500 --seed 1 --save_steps 500 --eval_steps 500
Here, <train file path> is the path to the training .txt file, which contains one sentence per line; the validation file has the same format.
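For illustration, this is roughly how I expect the script to read those files (a sketch using the datasets library with made-up file names and contents, not the script's exact code):

```python
from datasets import load_dataset

# Hypothetical stand-ins for <train file path> and <validation file path>;
# each .txt file holds one plain-text sentence per line, e.g.:
#   The quick brown fox jumps over the lazy dog.
#   Another unlabeled sentence for the span-masking objective.
data_files = {"train": "train.txt", "validation": "valid.txt"}

# The "text" loader yields one example per line, in a single 'text' column.
raw_datasets = load_dataset("text", data_files=data_files)
print(raw_datasets["train"][0])  # {'text': 'The quick brown fox jumps over the lazy dog.'}
```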
This is the error I get:
Stack trace:
Traceback (most recent call last):
File "run_t5_mlm_flax.py", line 964, in <module>
main()
File "run_t5_mlm_flax.py", line 877, in main
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2156, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2038, in pmap_f
global_arg_shapes=p.global_arg_shapes_flat)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2040, in bind
return map_bind(self, fun, *args, **params)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2072, in map_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 2043, in process
return trace.process_map(self, fun, tracers, params)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 910, in xla_pmap_impl
*abstract_args)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 295, in memoized_fun
ans = call(fun, *args)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 937, in parallel_callable
in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 1109, in lower_parallel_callable
pci, fun, global_arg_shapes)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/pxla.py", line 1016, in stage_parallel_callable
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2177, in trace_to_jaxpr_final
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2109, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_t5_mlm_flax.py", line 813, in train_step
loss, grad = grad_fn(state.params)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 1071, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/api.py", line 2517, in _vjp
flat_fun, primals_flat, reduce_axes=reduce_axes)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/ad.py", line 133, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/ad.py", line 122, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 802, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/users/PAA0201/rsan/.conda/envs/WikiHowProject/lib/python3.7/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_t5_mlm_flax.py", line 805, in loss_fn
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
jax._src.traceback_util.UnfilteredStackTrace: TypeError: __call__() got an unexpected keyword argument 'special_tokens_mask'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_t5_mlm_flax.py", line 964, in <module>
main()
File "run_t5_mlm_flax.py", line 877, in main
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
File "run_t5_mlm_flax.py", line 813, in train_step
loss, grad = grad_fn(state.params)
File "run_t5_mlm_flax.py", line 805, in loss_fn
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
TypeError: __call__() got an unexpected keyword argument 'special_tokens_mask'
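From the last frame it looks like the whole batch is forwarded into the model, so a special_tokens_mask column produced during tokenization reaches the Flax T5 model's __call__, which does not accept that argument. A workaround I am considering, as a sketch only, built from the variable names in the traceback rather than the script's actual code:

```python
# Sketch: drop the key the model's __call__ does not accept, right before the
# pmapped train step in main(). 'model_inputs' and 'p_train_step' are the names
# shown in the traceback; I am assuming the tokenizer/data collator is what
# adds 'special_tokens_mask' to the batch.
model_inputs.pop("special_tokens_mask", None)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
```

Is this a reasonable place to handle it, or should the tokenization step be changed so that special_tokens_mask is never added to the batch in the first place?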