What should I change in the snippet below to get consistent and accurate output?
from transformers import ElectraTokenizer, TFElectraForQuestionAnswering, ElectraConfig
import tensorflow as tf
configuration = ElectraConfig()
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
TFElect = TFElectraForQuestionAnswering(configuration)
#model = TFElectraForQuestionAnswering.from_pretrained('google/electra-small-discriminator')
model = TFElect.from_pretrained('google/electra-small-discriminator')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer(question, text, return_tensors='tf')
outputs = model(input_dict, return_dict=True)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"][0].numpy())
answer = ' '.join(all_tokens[int(tf.math.argmax(start_logits, 1)[0]) : int(tf.math.argmax(end_logits, 1)[0]) + 1])
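Separately from the weights issue, the span-decoding step is easy to get wrong: `tf.math.argmax` over a `(batch, seq_len)` logits tensor returns a length-1 tensor, not a Python int, and `input_ids` is 2-D, so the batch dimension has to be indexed away before slicing the token list. A minimal sketch of just that step, with numpy standing in for the tf tensors (the tokens and logit values below are made up for illustration):

```python
import numpy as np

# Stand-in for tokenizer.convert_ids_to_tokens(input_dict["input_ids"][0])
all_tokens = ["[CLS]", "who", "was", "jim", "henson", "?", "[SEP]",
              "jim", "henson", "was", "a", "nice", "puppet", "[SEP]"]

# Dummy (batch, seq_len) logits standing in for the model outputs.
start_logits = np.zeros((1, len(all_tokens)))
start_logits[0, 7] = 5.0
end_logits = np.zeros((1, len(all_tokens)))
end_logits[0, 8] = 5.0

# argmax along axis 1 gives a (batch,) array; pull out the scalar for batch 0.
start = int(np.argmax(start_logits, axis=1)[0])
end = int(np.argmax(end_logits, axis=1)[0])
answer = " ".join(all_tokens[start:end + 1])
# answer == "jim henson"
```

The same `int(...[0])` unwrapping applies to the `tf.math.argmax` results in the snippet above.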
Output: I get a different and incorrect answer every time I run it, so it seems the checkpoint has no pre-trained weights for the QA task. I also get the warning below:
Some layers from the model checkpoint at google/electra-small-discriminator were not used when initializing TFElectraForQuestionAnswering: ['discriminator_predictions']
- This IS expected if you are initializing TFElectraForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFElectraForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFElectraForQuestionAnswering were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['qa_outputs']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
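The warning is the explanation: the google/electra-small-discriminator checkpoint contains only the pre-trained discriminator body, so the `qa_outputs` span-prediction head is randomly re-initialized on every run, which is exactly why the answers change each time. The fix is to fine-tune the model on a QA dataset such as SQuAD, or to load a checkpoint from the Hub that was already fine-tuned for question answering. A toy numpy illustration of why an untrained head gives unstable spans (the dimensions are made up; this is not the actual ELECTRA code):

```python
import numpy as np

# Fixed "encoder" hidden states, as if produced by the pre-trained ELECTRA body.
hidden = np.random.default_rng(42).standard_normal((14, 256))  # (seq_len, hidden)

def fresh_qa_head(seed):
    # Stand-in for a newly initialized qa_outputs layer: hidden -> (start, end) logits.
    w = np.random.default_rng(seed).standard_normal((256, 2)) * 0.02
    return hidden @ w  # (seq_len, 2)

# Two "runs" of the script = two random initializations of the head.
logits_run1 = fresh_qa_head(seed=0)
logits_run2 = fresh_qa_head(seed=1)

# Same input, same encoder states, yet the predicted start positions can differ
# from run to run, because an untrained head is just random noise.
start1 = int(np.argmax(logits_run1[:, 0]))
start2 = int(np.argmax(logits_run2[:, 0]))
```

Only pre-training plus QA fine-tuning pins those head weights down, which is what makes the output both consistent and accurate.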