Constrain output format from beam search in Donut doc classification

Hello,

I’m trying to do active learning with Donut’s document classification model ("naver-clova-ix/donut-base-finetuned-rvlcdip"). The generated text is usually of the form '<s_rvlcdip><s_class><invoice/></s_class>', where <s_rvlcdip> is the prompt and <invoice/> is one of 16 class tokens. I want to compute the entropy of the categorical distribution over these classes, H = -Σ_i p_i log p_i, and use it to select informative documents for labeling.

To that end, I’m trying to use beam search to give me 16 beams, which would hopefully correspond to one class each, so that the sequence scores would be the log-probabilities of each class given the document. However, the generated sequences are not always of the desired form, so I’m trying to force beam search to always use both the <s_class> and </s_class> tokens, and to use at least one of the 16 class tokens; at the same time I’m setting min_length = max_length = 4 and no_repeat_ngram_size=1. This still leaves the order of the tokens unconstrained, but at present the problem is that the constraints are not being obeyed at all.

Here’s some of the code:

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# prepare decoder inputs
task_prompt = "<s_rvlcdip>"
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
# sample is a PIL image of one document
pixel_values = processor(sample, return_tensors="pt").pixel_values
# force beam search to include the <s_class> and </s_class> tokens,
# and at least one of the 16 class labels;
# with min_length = max_length = 4, the output should then always be of the form
# <s_rvlcdip><s_class><class_label/></s_class>
force_class_start = ["<s_class>"]
force_class_end = ["</s_class>"]
force_one_of = [
    "<invoice/>", "<budget/>", "<news_article/>", "<specification/>",
    "<scientific_report/>", "<scientific_publication/>", "<questionnaire/>",
    "<letter/>", "<advertisement/>", "<form/>", "<handwritten/>",
    "<file_folder/>", "<email/>", "<memo/>", "<resume/>", "<presentation/>",
]
force_words_ids = [
    # two phrasal constraints: <s_class> and </s_class> must each appear
    *processor.tokenizer(force_class_start, add_special_tokens=False).input_ids,
    *processor.tokenizer(force_class_end, add_special_tokens=False).input_ids,
    # one disjunctive constraint: at least one of the class tokens must appear
    processor.tokenizer(force_one_of, add_special_tokens=False).input_ids,
]
outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    force_words_ids=force_words_ids,
    max_length=4,
    min_length=4,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=16,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    num_return_sequences=16,
    return_dict_in_generate=True,
    output_scores=True,
    no_repeat_ngram_size=1,  # prevents repetition of tokens like <s_class><s_class>
)

and here’s an example output:

print(processor.batch_decode(outputs.sequences))

['<s_rvlcdip><s_class><invoice/></s_class>',
 '<s_rvlcdip><s_class><budget/></s_class>',
 '<s_rvlcdip><s_class><news_article/></s_class>',
 '<s_rvlcdip><s_class><specification/></s_class>',
 '<s_rvlcdip><s_class> අරන්</s_class>',
 '<s_rvlcdip><s_class> 같습니다</s_class>',
 '<s_rvlcdip><s_class> منهن</s_class>',
 '<s_rvlcdip><s_class> අරන් красива',
 '<s_rvlcdip><s_class> අරන් ракета',
 '<s_rvlcdip><s_class> අරන් находиться',
 '<s_rvlcdip> кетсе<specification/><s_class>',
 '<s_rvlcdip> ඔලුව<advertisement/><s_class>',
 '<s_rvlcdip> кетсе<advertisement/><s_class>',
 '<s_rvlcdip></s_class><scientific_publication/><s_class>',
 '<s_rvlcdip></s_class><scientific_report/><s_class>',
 '<s_rvlcdip></s_class><email/><s_class>']

Here’s my full code: https://pastebin.com/pgHMKFmK
When debugging, I can see the force_words_ids are correctly translated into two PhrasalConstraints and a DisjunctiveConstraint, but then either ConstrainedBeamSearchScorer or constrained_beam_search has a bug, or I’m doing something wrong.
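In case it helps narrow things down, here’s roughly what I believe the equivalent explicit constraints look like; as far as I understand, generate() also accepts Constraint objects directly via the constraints argument (an untested sketch):

    from transformers import PhrasalConstraint, DisjunctiveConstraint

    # tokenizing a single string gives a flat list of ids;
    # tokenizing the list of class tokens gives a list of lists
    start_ids = processor.tokenizer("<s_class>", add_special_tokens=False).input_ids
    end_ids = processor.tokenizer("</s_class>", add_special_tokens=False).input_ids
    class_ids = processor.tokenizer(force_one_of, add_special_tokens=False).input_ids

    constraints = [
        PhrasalConstraint(start_ids),      # <s_class> must appear
        PhrasalConstraint(end_ids),        # </s_class> must appear
        DisjunctiveConstraint(class_ids),  # at least one class token must appear
    ]
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        constraints=constraints,
        max_length=4,
        min_length=4,
        num_beams=16,
        num_return_sequences=16,
        return_dict_in_generate=True,
        output_scores=True,
    )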

Is there a better way to compute a consistent probability distribution, or a way to fix the approach here?

I don’t think the beams will line up with the classes like that. Wouldn’t it be easier to just check the logits of the token that the model predicts after the prompt? That gives you scores over all tokens in the vocabulary of Donut’s decoder, so you could look only at the scores of the 16 class tokens and normalize those to get a categorical distribution.

Wouldn’t it be easier to just check the logits of the token that the model predicts after the prompt?

That would be great! I think the model is trained to output the <s_class> token immediately after the prompt, but is there an easy way to get the logits of the next token conditioned on <s_class> being the first generated token? Or should I just add <s_class> to the task prompt?

Oh yeah, sorry, indeed. I’d check the probability distribution after the “<s_class>” token (or after whatever you’ve decided to make the model generate before the first class token; the “<s_rvlcdip><s_class>” prompt was just a choice of the Donut authors).

You can provide output_scores=True to the generate method.
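Something like this, for instance (just a sketch, untested):

    # extend the decoder prompt so the next generated token is the class token
    decoder_input_ids = processor.tokenizer(
        "<s_rvlcdip><s_class>", add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=decoder_input_ids.shape[1] + 1,  # generate exactly one more token
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # outputs.scores[0] has shape (1, vocab_size): the scores of the first
    # generated token, i.e. the class-token position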

I see. I’ve tried running with min_length = max_length = 4, num_beams = num_return_sequences = 1 and no forced words, and outputs.scores now contains three tensors of shape (1, 57544), one per generated token. So the second tensor holds the scores over all tokens given the prompt and the first generated token, from which I can pick out the 16 class tokens and renormalize:

    class_tokens = [
        "<invoice/>",
        "<budget/>",
        "<news_article/>",
        "<specification/>",
        "<scientific_report/>",
        "<scientific_publication/>",
        "<questionnaire/>",
        "<letter/>",
        "<advertisement/>",
        "<form/>",
        "<handwritten/>",
        "<file_folder/>",
        "<email/>",
        "<memo/>",
        "<resume/>",
        "<presentation/>",
    ]
    class_token_ids = processor.tokenizer(
        class_tokens, add_special_tokens=False
    ).input_ids  # a list of 16 one-element lists, one id per class token

    # scores[1] holds the scores of the second generated token, i.e. the class token
    logit_p_2nd_token_given_prompt_and_1st_token = outputs.scores[1][0, :].cpu()
    logit_p_class = logit_p_2nd_token_given_prompt_and_1st_token[
        torch.tensor(class_token_ids).squeeze()
    ]
    p_class = torch.softmax(logit_p_class, 0)
    entropy = -torch.log(p_class) @ p_class  # H = -sum_i p_i log p_i

Does that look right?
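(Assuming it is, the selection step of the active learning loop would then just rank the unlabeled pool by entropy. compute_entropy, unlabeled_pool and budget below are hypothetical placeholders:)

    def compute_entropy(doc_image):
        # hypothetical wrapper: run the snippet above for one document
        # image and return the scalar entropy
        ...

    # label the highest-entropy (most uncertain) documents first
    ranked = sorted(unlabeled_pool, key=compute_entropy, reverse=True)
    to_label = ranked[:budget]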

(On another note, I think there still might be a problem with the constrained beam search, as I never got the output sequences to obey the constraints.)