How to load a ckpt into my model based on TF 2.x

I need to load a ckpt file from Google Research's BERT checkpoints.

I have read all the questions related to this and tried some of the methods, but it still doesn't work.

The ckpt files look like this:

[screenshot of the checkpoint files]

I have tried methods like:

bert_config = transformers.BertConfig.from_json_file('./bert/bert_config.json')
bert = transformers.TFBertModel.from_pretrained('./bert/bert_model.ckpt', config=bert_config)

I found a method that said to use TFBertForPreTraining instead of TFBertModel:

bert_config = transformers.BertConfig.from_json_file('./bert/bert_config.json')
bert = transformers.TFBertForPreTraining.from_pretrained('./bert/bert_model.ckpt', config=bert_config)

but it still doesn’t work.

Also,

bert.load_weights('./bert/bert_model.ckpt')

doesn't work either.

I really can't understand why bert.load_weights('./bert/bert_model.ckpt') doesn't work. Maybe I need to rewrite it?

So how can a model based on TF2 load a ckpt file?

Hello? Can someone help me?

hey @Sniper when you say it “doesn’t work” what do you mean exactly? can you share a stack trace with the error message?

if i am not mistaken, you need to provide the full path to the checkpoint (including the .index suffix), so something like the following might work:

bert_config = transformers.BertConfig.from_json_file('./bert/bert_config.json')
bert = transformers.TFBertModel.from_pretrained('./bert/bert_model.ckpt.index', config=bert_config)
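one thing that often trips people up here: TF's own checkpoint readers (tf.train.load_checkpoint, tf.train.list_variables) expect the checkpoint *prefix*, i.e. the path without the .index or .data-* suffix. a small sketch for normalizing the path (ckpt_prefix is a hypothetical helper, not part of transformers or TF):

```python
import re

# tf.train.load_checkpoint / tf.train.list_variables expect the checkpoint
# *prefix*: the path without the ".index" or ".data-XXXXX-of-YYYYY" suffix.
def ckpt_prefix(path):
    """Hypothetical helper: strip TF1 checkpoint file suffixes to get the prefix."""
    if path.endswith(".index"):
        return path[: -len(".index")]
    return re.sub(r"\.data-\d{5}-of-\d{5}$", "", path)
```

with the prefix in hand, tf.train.list_variables(ckpt_prefix(path)) should list the variables stored in the checkpoint, which is a quick sanity check that the files are intact.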

Hi, thanks for your reply @lewtun.

"It doesn't work" means the ckpt cannot be loaded.

When I use model.load_weights(), the model produces a different output each time I run the same code. I think the weights are being randomly initialized instead of loaded from the ckpt.

When I use transformers to load the ckpt (also with the .index suffix), it shows an error message like this:

2021-06-02 10:10:12.841564: I tensorflow/stream_executor/cuda/cuda_blas.cc:1838] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
Traceback (most recent call last):
  File "C:\anaconda\envs\tf2\lib\site-packages\transformers\modeling_tf_utils.py", line 1271, in from_pretrained
    missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
  File "C:\anaconda\envs\tf2\lib\site-packages\transformers\modeling_tf_utils.py", line 467, in load_tf_weights
    with h5py.File(resolved_archive_file, "r") as f:
  File "C:\anaconda\envs\tf2\lib\site-packages\h5py\_hl\files.py", line 408, in __init__
swmr=swmr)
  File "C:\anaconda\envs\tf2\lib\site-packages\h5py\_hl\files.py", line 173, in make_fid
    fid = h5f.open(name, flags, fapl=fapl)
  File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py\h5f.pyx", line 88, in h5py.h5f.open
OSError: Unable to open file (file signature not found)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/", line 13, in <module>
    bert = transformers.TFBertModel.from_pretrained('./bert/bert_model.ckpt.index', config=bert_config)
  File "C:\anaconda\envs\tf2\lib\site-packages\transformers\modeling_tf_utils.py", line 1274, in from_pretrained
    "Unable to load weights from h5 file. "
OSError: Unable to load weights from h5 file. If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True.

The code looks like this:

import tensorflow as tf
import transformers

for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
bert_config = transformers.BertConfig.from_json_file('./bert/bert_config.json')
bert = transformers.TFBertModel.from_pretrained('./bert/bert_model.ckpt.index', config=bert_config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = bert(inputs)

last_hidden_states = outputs.last_hidden_state
print(last_hidden_states)

ah ok now i understand the problem: all the TFBertXxx.from_pretrained methods are based on keras and thus look for checkpoints with a .h5 extension.
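you can actually see this in the traceback: h5py's "file signature not found" means the file does not start with the fixed 8-byte HDF5 magic, and a TF1 checkpoint file never does. a quick sketch to check it yourself (looks_like_hdf5 / is_hdf5_file are hypothetical helpers; h5py performs an equivalent check internally):

```python
# Every HDF5 file begins with this fixed 8-byte signature.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"

def looks_like_hdf5(first_bytes):
    """Hypothetical helper: True if a file's leading bytes carry the HDF5 magic."""
    return first_bytes[:8] == HDF5_SIGNATURE

def is_hdf5_file(path):
    """Read the first 8 bytes of a file and check for the HDF5 signature."""
    with open(path, "rb") as f:
        return looks_like_hdf5(f.read(8))
```

running is_hdf5_file on bert_model.ckpt.index should return False, which is exactly why from_pretrained raises that OSError.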

i think we can try using the load_tf_weights function instead (link):

from transformers import BertTokenizer, BertConfig, TFBertModel, load_tf_weights_in_bert

# load config
config = BertConfig.from_json_file("./bert_config.json")
# instantiate model from config
model = TFBertModel(config)
# load pretrained weights
load_tf_weights_in_bert(model, config, "./bert_model.ckpt.index")
# instantiate tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
# returns tensor of shape [1, 8, hidden_dim]
model(**inputs).last_hidden_state
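for context, load_tf_weights_in_bert walks the variable names stored in the checkpoint and maps them onto the model's attributes, skipping the Adam optimizer slots that TF1 checkpoints also carry. a simplified sketch of that filtering step (not the exact library code):

```python
# TF1 BERT checkpoints also store optimizer state; these name parts mark
# variables that hold no model weights and should be skipped on load.
SKIP = {"adam_v", "adam_m", "AdamWeightDecayOptimizer",
        "AdamWeightDecayOptimizer_1", "global_step"}

def keep_variable(name):
    """True if a TF1 checkpoint variable holds model weights, not optimizer state."""
    return not any(part in SKIP for part in name.split("/"))
```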

out of curiosity, why do you want to do things this way? is there a checkpoint missing on the Hub that you need?

So the load_tf_weights_in_bert method still reads the ckpt using PyTorch.
Does that mean I can't load a ckpt file in an environment that only has TensorFlow?
That's really complex.

Frameworks like keras-bert and keras4bert can both load ckpt files. Could transformers add this method in a new version?

Because those of us who are not native English speakers need to use models like bert-chinese-tiny, CHINESE ALBERT (pretrained by a Chinese team), and even the models from Google Research's GitHub. They can be downloaded from the Internet easily.

But it seems like loading them into a model is really difficult when using transformers.

Hi! I’m the TF maintainer at Hugging Face. I think what you want should be possible, and I agree that it’s annoying that you need PyTorch to use load_tf_weights_in_bert.

I’ve got a deadline to complete right now, but I’m going to try to look at this early next week (Monday/Tuesday). Hopefully I can get the load to work, and then I’ll share how I did it, and possibly update our load_tf_weights_in_bert function if necessary.


Hi @Rocketknight1,
Hearty thanks for your kind help.

Because of my poor English, I still have a question.
When you finish the load_tf_weights_in_bert work, will it be shown in this topic or uploaded to GitHub?
Can I get the code when you finish it?

@Sniper Unfortunately, I couldn’t find a simple solution - probably the best thing to do is to just install PyTorch (cpu-only is fine) to convert the checkpoints. I’m going to flag this issue for development so that we don’t depend on PyTorch for this in future.
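For anyone landing here later, the PyTorch-based conversion can be sketched roughly like this. This is only a sketch, assuming torch, tensorflow and transformers are installed; convert_tf1_to_tf2 is a hypothetical helper, and passing a .ckpt.index path with from_tf=True is the documented way to read TF1 checkpoints with the PyTorch model classes:

```python
def convert_tf1_to_tf2(ckpt_index_path, config_json_path, out_dir):
    """Hypothetical helper: TF1 ckpt -> PyTorch dir -> TF2 model.

    Imports are kept inside the function so this sketch can be read
    (and the file imported) without the heavy dependencies installed.
    """
    from transformers import BertConfig, BertForPreTraining, TFBertModel

    config = BertConfig.from_json_file(config_json_path)
    # from_tf=True lets the *PyTorch* class parse a TF1 ".ckpt.index" file
    pt_model = BertForPreTraining.from_pretrained(
        ckpt_index_path, config=config, from_tf=True
    )
    pt_model.save_pretrained(out_dir)  # writes pytorch_model.bin + config.json
    # reload the converted weights into a TF2/Keras model
    return TFBertModel.from_pretrained(out_dir, from_pt=True)
```

Usage would be something like convert_tf1_to_tf2("./bert/bert_model.ckpt.index", "./bert/bert_config.json", "./bert-pt"); after that, the saved directory can be loaded with plain from_pretrained calls.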

OK, I got it.
Thank you for your patience.