Help with training a TensorFlow model for distilbert-squad

Hi!

My name is Felipe Adachi. I’m an absolute beginner with PyTorch/Transformers and would appreciate any help I can get.

I’m trying to fine-tune a TensorFlow DistilBERT model on a SQuAD dataset, so far without success. Can anyone tell me if there’s a more direct way of doing this, or help me with my current attempt?

My current attempt uses the run_tf_squad.py script, which fails with the error shown below.

I have a custom SQuAD-format dataset in a folder named train_file, and I’m calling the script like this:
python run_tf_squad.py --model_name_or_path distilbert-base-uncased --output_dir model --max_seq_length 384 --num_train_epochs 2 --data_dir train_file --use_tfds False --per_gpu_train_batch_size 8 --per_gpu_eval_batch_size 16 --do_train --logging_dir logs --logging_steps 10 --learning_rate 3e-5 --doc_stride 128

But I’m getting this error:

Traceback (most recent call last):

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 897, in generator_py_func
    flattened_values = nest.flatten_up_to(output_types, values)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 396, in flatten_up_to
    assert_shallow_structure(shallow_tree, input_tree)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 323, in assert_shallow_structure
    assert_shallow_structure(shallow_branch, input_branch,

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 308, in assert_shallow_structure
    raise ValueError(

ValueError: The two structures don't have the same sequence length. Input structure has length 5, while shallow structure has length 4.


During handling of the above exception, another exception occurred:


Traceback (most recent call last):

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\ops\script_ops.py", line 249, in __call__
    ret = func(*args)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 620, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 899, in generator_py_func
    six.reraise(

  File "C:\Users\felip\miniconda3\lib\site-packages\six.py", line 702, in reraise
    raise value.with_traceback(tb)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 897, in generator_py_func
    flattened_values = nest.flatten_up_to(output_types, values)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 396, in flatten_up_to
    assert_shallow_structure(shallow_tree, input_tree)

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 323, in assert_shallow_structure
    assert_shallow_structure(shallow_branch, input_branch,

  File "C:\Users\felip\miniconda3\lib\site-packages\tensorflow\python\data\util\nest.py", line 308, in assert_shallow_structure
    raise ValueError(

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was ({'input_ids': tf.int32, 'attention_mask': tf.int32, 'feature_index': tf.int64, 'qas_id': tf.string}, {'start_positions': tf.int64, 'end_positions': tf.int64, 'cls_index': tf.int64, 'p_mask': tf.int32, 'is_impossible': tf.int32}), but the yielded element was ({'input_ids': [101, 1996, 9093, 1011, 2203, 2323, 2022, 8642, 4417, 2007, 1996, 12407, 1520, 22563, 1521, 1012, 102, 1996, 2592, 2089, 2022, 2649, 1999, 1996, 2236, 9254, 2030, 2060, 6254, 1010, 2104, 5227, 1012, 2174, 1010, 3499, 3085, 6537, 2097, 2025, 2022, 4417, 2006, 1996, 18929, 1012, 1000, 22563, 1000, 10060, 2097, 2025, 2022, 3024, 2144, 2009, 1005, 2222, 4254, 2006, 2256, 4722, 5814, 2832, 1012, 1996, 1000, 22563, 1000, 10060, 2064, 2022, 3024, 1999, 2019, 3176, 5127, 1010, 2104, 5227, 1012, 1996, 2598, 2075, 1997, 1996, 5013, 2097, 2025, 2022, 23290, 2692, 1012, 1996, 2598, 2075, 11320, 5620, 1997, 1996, 9693, 2024, 7218, 2005, 1024, 1996, 2598, 2075, 17703, 2097, 2025, 2022, 4417, 2007, 2019, 1000, 1041, 1000, 1012, 1000, 22563, 1000, 10060, 2097, 2025, 2022, 3024, 2144, 2009, 1005, 2222, 4254, 2006, 2256, 4722, 5814, 2832, 1012, 1996, 1000, 22563, 1000, 10060, 2064, 2022, 3024, 1999, 2019, 3176, 5127, 1012, 2203, 14257, 2064, 2022, 6727, 2011, 2057, 2290, 2104, 5227, 1012, 2174, 1010, 2009, 1005, 2222, 2025, 2022, 4417, 2006, 1996, 9093, 2203, 1012, 2045, 2003, 1037, 5013, 2007, 6758, 2373, 1997, 5539, 2243, 2860, 2108, 2881, 2004, 2659, 10004, 2004, 2566, 12407, 2013, 8013, 1012, 1996, 9693, 2024, 8790, 3973, 12042, 2007, 2431, 1011, 3145, 1010, 2174, 1996, 2828, 1997, 20120, 2097, 2025, 2022, 5393, 1999, 1996, 1996, 9093, 1011, 2203, 1012, 2019, 3176, 5127, 2064, 2022, 8127, 2007, 2023, 2592, 1012, 2057, 2290, 3640, 1037, 2236, 7565, 5059, 2029, 3065, 1996, 9093, 1011, 2203, 1997, 1996, 5013, 1012, 2107, 5059, 2003, 8127, 10329, 2013, 1996, 22834, 2213, 1997, 1996, 5013, 1012, 2079, 
2140, 9693, 2024, 4654, 1011, 14925, 2004, 2566, 12407, 2013, 8013, 1999, 21792, 4160, 1012, 2566, 26770, 6537, 1997, 18929, 2929, 2097, 2025, 2022, 4417, 1999, 1996, 9093, 2203, 1012, 10060, 1000, 22563, 1000, 2003, 2025, 2005, 19763, 2078, 1999, 2057, 2290, 9922, 2832, 1012, 2045, 2003, 2053, 7635, 2005, 5830, 8720, 1012, 1996, 15196, 2024, 4417, 1012, 1996, 18929, 2015, 2024, 8790, 3973, 12042, 2007, 2431, 3145, 1012, 2174, 1010, 2045, 2003, 2053, 4568, 10060, 1999, 1996, 9093, 1011, 2203, 1012, 2035, 18929, 2015, 2097, 2022, 12042, 2007, 2431, 3145, 1012, 6854, 2057, 2123, 2102, 2031, 4713, 2832, 2000, 2928, 1996, 9093, 2007, 1000, 22563, 1000, 12407, 1012, 3531, 5136, 1996, 11343, 1997, 4374, 2023, 13964, 2007, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'feature_index': 0, 'qas_id': '1dc7819e3b1611eb9c777c8ae1da80a8'}, {'start_positions': 119, 'end_positions': 150, 'cls_index': 0, 'p_mask': [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 'is_impossible': False}).


         [[{{node PyFunc}}]]
         [[MultiDeviceIteratorGetNextFromShard]]
         [[RemoteCall]]
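
Reading the message above, the "length 5 vs. length 4" mismatch seems to come from the first (input) dictionary: the expected structure lists four keys, while the yielded element has five. The label dictionary has the same five keys on both sides. Below is a minimal sketch in plain Python that compares the two key sets. The key names are copied from the error message; the helper name diff_keys is hypothetical and not part of TensorFlow or Transformers.

```python
# Key names copied from the TypeError message; dtypes and values are omitted.
expected_input_keys = {"input_ids", "attention_mask", "feature_index", "qas_id"}
yielded_input_keys = {
    "input_ids",
    "attention_mask",
    "token_type_ids",
    "feature_index",
    "qas_id",
}


def diff_keys(expected, yielded):
    """Return (unexpected, missing) keys as sorted lists.

    Hypothetical helper for illustration only.
    """
    unexpected = sorted(set(yielded) - set(expected))
    missing = sorted(set(expected) - set(yielded))
    return unexpected, missing


unexpected, missing = diff_keys(expected_input_keys, yielded_input_keys)
print("unexpected keys:", unexpected)
print("missing keys:", missing)
```

So the only difference visible in the message is the extra 'token_type_ids' entry in the yielded features. This only localizes the mismatch; it does not by itself say whether the fix belongs in the script, the dataset conversion, or the library.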

TensorFlow: 2.4.1
Transformers: 4.3.0.dev0
Python: 3.8.5

Thank you in advance!