TrOCR inference

@nielsr, I am using trocr-printed for inference with the code below, and it works fine except that len(tuple_of_logits) is always 19, no matter what batch_size I use. Even when I override model.decoder.config.max_length from 20 to 10, len(tuple_of_logits) is still 19.

Can you please help me figure out what I am missing here?

for batch in tqdm(test_dataloader):
    # predict using generate
    pixel_values = batch["pixel_values"].to(device)
    outputs = model.generate(pixel_values, output_scores=True, return_dict_in_generate=True)
    # outputs.scores is a tuple with one logits tensor per decoding step
    tuple_of_logits = outputs.scores
    print(len(tuple_of_logits))

Hi,

You can adjust the maximum number of generated tokens by specifying the max_length parameter of the generate method. generate takes its defaults from the model's top-level config (model.config.max_length, which defaults to 20), so overriding model.decoder.config.max_length has no effect. And since outputs.scores contains one score tensor per generated token, while the decoder start token is not scored, max_length=20 yields at most 19 entries:

outputs = model.generate(pixel_values, output_scores=True, return_dict_in_generate=True, max_length=10)

Note that this is explained in the docs.
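For completeness, here is a minimal end-to-end sketch for a single image; the checkpoint name and the image path "line.png" are placeholders, assuming the usual TrOCR setup:

from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

# preprocess a single text-line image into pixel values ("line.png" is a placeholder)
image = Image.open("line.png").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

outputs = model.generate(pixel_values, output_scores=True, return_dict_in_generate=True, max_length=10)
# at most max_length - 1 score tensors (the decoder start token is not scored);
# fewer if generation stops early at the EOS token
print(len(outputs.scores))
print(processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0])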
