Convert TAPAS TF checkpoint to PyTorch

Hello,

There’s a recent paper by Google Research called TAPAS. It is basically a BERT model that is pretrained with masked language modeling (MLM) to answer questions about tables. Their idea was to collect a lot of text snippets + corresponding tables from Wikipedia and then use this data for masked language modeling. The tables are flattened in order to serve them as input to BERT, so the input typically looks like [CLS] text snippet [SEP] flattened table.
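For instance, a toy version of that flattening could look something like this (just an illustration of the general idea, not the exact serialization TAPAS uses, which also tracks row/column indices per cell):

```python
# Toy illustration of flattening a table behind a text snippet.
# NOT the exact TAPAS serialization; it only shows the general idea.
question = "which country won in 2018?"
table = [
    ["Year", "Country"],
    ["2014", "Germany"],
    ["2018", "France"],
]

flattened = " ".join(cell for row in table for cell in row)
model_input = f"[CLS] {question} [SEP] {flattened}"
print(model_input)
# [CLS] which country won in 2018? [SEP] Year Country 2014 Germany 2018 France
```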

The pretrained models they released are all made using TensorFlow. I’d like to convert them to PyTorch models using the script of the Transformers library as explained in the docs. However, the model has a couple of additional embedding layers added to BERT to learn about the structure of tables. I printed out the first TF variables of the checkpoint:

[('bert/embeddings/LayerNorm/beta', [768]),
('bert/embeddings/LayerNorm/beta/adam_m', [768]), 
('bert/embeddings/LayerNorm/beta/adam_v', [768]),
('bert/embeddings/LayerNorm/gamma', [768]),
('bert/embeddings/LayerNorm/gamma/adam_m', [768]),
('bert/embeddings/LayerNorm/gamma/adam_v', [768]),
('bert/embeddings/position_embeddings', [1024, 768]),
('bert/embeddings/position_embeddings/adam_m', [1024, 768]),
('bert/embeddings/position_embeddings/adam_v', [1024, 768]),
('bert/embeddings/token_type_embeddings_0', [3, 768]),
('bert/embeddings/token_type_embeddings_0/adam_m', [3, 768]),
('bert/embeddings/token_type_embeddings_0/adam_v', [3, 768]),
('bert/embeddings/token_type_embeddings_1', [256, 768]),
('bert/embeddings/token_type_embeddings_1/adam_m', [256, 768]),
('bert/embeddings/token_type_embeddings_1/adam_v', [256, 768]),
('bert/embeddings/token_type_embeddings_2', [256, 768]),
('bert/embeddings/token_type_embeddings_2/adam_m', [256, 768]),
('bert/embeddings/token_type_embeddings_2/adam_v', [256, 768]),
('bert/embeddings/token_type_embeddings_3', [2, 768]),
('bert/embeddings/token_type_embeddings_3/adam_m', [2, 768]),
('bert/embeddings/token_type_embeddings_3/adam_v', [2, 768]),
('bert/embeddings/token_type_embeddings_4', [256, 768]),
('bert/embeddings/token_type_embeddings_4/adam_m', [256, 768]),
('bert/embeddings/token_type_embeddings_4/adam_v', [256, 768]),
('bert/embeddings/token_type_embeddings_5', [256, 768]),
('bert/embeddings/token_type_embeddings_5/adam_m', [256, 768]),
('bert/embeddings/token_type_embeddings_5/adam_v', [256, 768]),
('bert/embeddings/token_type_embeddings_6', [10, 768]),
('bert/embeddings/token_type_embeddings_6/adam_m', [10, 768]),
('bert/embeddings/token_type_embeddings_6/adam_v', [10, 768]),
('bert/embeddings/word_embeddings', [30522, 768]),
('bert/embeddings/word_embeddings/adam_m', [30522, 768]),
('bert/embeddings/word_embeddings/adam_v', [30522, 768]),
(...)
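For completeness, the listing above can be produced with something along these lines (tf.train.list_variables works in both TF 1.x and 2.x; the checkpoint path below is just a placeholder for wherever the downloaded files live):

```python
import tensorflow as tf

# hypothetical path to the extracted TAPAS checkpoint
checkpoint_path = "tapas_model/model.ckpt"

for name, shape in tf.train.list_variables(checkpoint_path):
    print((name, shape))
```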

As one can see, there are 7 token type embedding layers (token_type_embeddings_0 through token_type_embeddings_6) rather than 1. How should I change the source code of modeling_bert so that I can use the load_tf_weights_in_bert function?
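What I have in mind is something like the sketch below: an embeddings module with one nn.Embedding per token type dimension, with attribute names matching the TF variable names, and a hypothetical type_vocab_sizes config field holding the sizes from the checkpoint ([3, 256, 256, 2, 256, 256, 10]). Since load_tf_weights_in_bert resolves TF scope names to module attributes via getattr (and already skips the adam_m/adam_v optimizer slots), matching the names seems like a reasonable starting point, but I’m not sure whether the loading function itself also needs changes (as far as I can tell it only grabs the .weight of variables whose name ends in "_embeddings", which token_type_embeddings_0 etc. do not).

```python
# Minimal sketch of a TAPAS-like embeddings module, NOT the official implementation.
# config.type_vocab_sizes is an assumed custom field, e.g. [3, 256, 256, 2, 256, 256, 10];
# the other config attributes are the standard BertConfig ones.
import torch
import torch.nn as nn


class TapasLikeEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # one embedding table per token type dimension, named after the TF variables
        for i, type_vocab_size in enumerate(config.type_vocab_sizes):
            setattr(self, f"token_type_embeddings_{i}", nn.Embedding(type_vocab_size, config.hidden_size))
        self.number_of_token_type_embeddings = len(config.type_vocab_sizes)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids, position_ids=None):
        # token_type_ids is expected to have shape (batch, seq_len, number_of_token_type_embeddings),
        # i.e. one id per token per token type dimension
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        for i in range(self.number_of_token_type_embeddings):
            embedding_layer = getattr(self, f"token_type_embeddings_{i}")
            embeddings = embeddings + embedding_layer(token_type_ids[:, :, i])

        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)
```

Does that look like the right direction, or would it be better to subclass BertEmbeddings and adapt load_tf_weights_in_bert directly?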
