I’m trying to run some text classification, and was initially successful in getting the run_text_classification.py script from the examples to predict on a small test set using a model I pretrained (with the same script). This initial success was on an interactive node of a shared compute server. The command I used was:
python run_text_classification.py --model_name_or_path /BERT-OUTPUT/COV-CT --output_dir /BERT-OUTPUT/Prediction/COV-CT/ --test_file test-tweets-to-predict.csv
However, when attempting to run the same command on a larger set of inputs (submitted as a batch job), I encountered the following error:
Traceback (most recent call last):
  File "/scratch/st-tlemieux-1/lfrymire/ML_scripts/run_text_classification.py", line 531, in <module>
    main()
  File "/scratch/st-tlemieux-1/lfrymire/ML_scripts/run_text_classification.py", line 506, in main
    predictions = model.predict(tf_data["test"])["logits"]
  File "/home/lfrymire/.local/lib/python3.8/site-packages/transformers/file_utils.py", line 1887, in __getitem__
    return inner_dict[k]
KeyError: 'logits'
When I went back to test and debug on the interactive node, I was able to run it several times without issue, but strangely I now get the same error on both nodes. This now happens with both the small sample set and the larger set I’m actually trying to predict, and in all cases I am running the same command. Any help is greatly appreciated!
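For context, here is a minimal, self-contained sketch of the failure mode as I understand it: line 506 of the script indexes the predict output with ["logits"], and the KeyError suggests the returned output simply doesn't contain that key. The get_logits helper and the simulated dicts below are hypothetical stand-ins (not from the script) that mimic how transformers' ModelOutput raises on a missing key:

```python
def get_logits(outputs):
    """Return the 'logits' entry if present; otherwise raise with the keys
    that ARE present, to aid diagnosis. Assumption: `outputs` behaves like
    a dict, as transformers' ModelOutput does."""
    if "logits" in outputs:
        return outputs["logits"]
    raise KeyError(f"'logits' not found; available keys: {list(outputs.keys())}")

# Simulated return values standing in for model.predict(tf_data["test"])
ok = {"logits": [[2.1, -0.3]]}      # what the script expects
bad = {"predictions": [[0.9, 0.1]]}  # hypothetical shape that would trigger the error

print(get_logits(ok))
try:
    get_logits(bad)
except KeyError as err:
    print(err)  # reports which keys were actually returned
```

Printing `list(outputs.keys())` (or the raw output) right before line 506 should show what the model is actually returning in the failing runs.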
- `transformers` version: 4.9.0
- Platform: Linux-3.10.0-1160.24.1.el7.x86_64-x86_64-with-glibc2.2.5
- Python version: 3.8.3
- PyTorch version (GPU?): not installed (NA)
- Tensorflow version (GPU?): 2.4.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No