KeyError: 'logits' while predicting with run_text_classification example

I’m trying to run text classification, and was initially able to get the example script to predict on a small test sample using a model I had previously trained (using the same script). That initial successful run was on an interactive node of a shared compute server. The command I used was:

python --model_name_or_path /BERT-OUTPUT/COV-CT --output_dir /BERT-OUTPUT/Prediction/COV-CT/ --test_file test-tweets-to-predict.csv

However, when attempting to run the same command on a larger set of inputs (submitted as a batch job), I encountered the following error:

Traceback (most recent call last):
  File "/scratch/st-tlemieux-1/lfrymire/ML_scripts/", line 531, in <module>
  File "/scratch/st-tlemieux-1/lfrymire/ML_scripts/", line 506, in main
    predictions = model.predict(tf_data["test"])["logits"]
  File "/home/lfrymire/.local/lib/python3.8/site-packages/transformers/", line 1887, in __getitem__
    return inner_dict[k]
KeyError: 'logits'

When I went back to the interactive node to debug, the command ran without issue several times, but strangely I now get the same error on both nodes. This happens for both the small sample file and the larger set I actually want to predict on. In all cases I am running the same command. Any help is greatly appreciated!
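One way to narrow this down might be to check which keys the prediction output actually contains before indexing it with `"logits"`. The snippet below is only a minimal sketch: `extract_logits` is a hypothetical helper (not part of `transformers`), and it assumes the object returned by `model.predict(...)` is either dict-like or already the raw array of logits.

```python
from collections.abc import Mapping


def extract_logits(output, key="logits"):
    """Return output[key] if present, otherwise fail with a descriptive error.

    Hypothetical debugging helper; not part of transformers.
    """
    if isinstance(output, Mapping):
        if key in output:
            return output[key]
        # Report what is actually present instead of a bare KeyError.
        raise KeyError(
            f"{key!r} not found in prediction output; "
            f"available keys: {list(output.keys())}"
        )
    # Assumption: a non-mapping output is already the raw array of logits.
    return output


# Stand-in for a prediction output, for illustration only.
example_output = {"logits": [[0.1, 0.9]]}
print(extract_logits(example_output))
```

In the script, this would replace the direct indexing on the failing line, for example `predictions = extract_logits(model.predict(tf_data["test"]))`, so that a failure lists the keys that were returned.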

Environment information:

- `transformers` version: 4.9.0
- Platform: Linux-3.10.0-1160.24.1.el7.x86_64-x86_64-with-glibc2.2.5
- Python version: 3.8.3
- PyTorch version (GPU?): not installed (NA)
- Tensorflow version (GPU?): 2.4.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No