Asymmetry in validation step vs. autoregressive inference

I’m training a Transformer model (BartForConditionalGeneration) to assign MIDI notes to guitar strings as a larger project on automatic guitar tablature generation. The training data are MIDI files with 6 tracks, one for each string. Each note appears on the track corresponding to the string it is played on. My tokenizer creates four tokens per note:

..., TimeShift, StringNumber, Pitch, Duration, ...

I formulate the problem as a form of masked language modeling. At training time, we mask out the StringNumber in the input (so the 2nd token in each block of 4).

..., TimeShift, MASK, Pitch, Duration, ...

The targets/labels ignore everything except the StringNumber

..., -100, StringNumber, -100, -100, ...

And the decoder_input_ids are the shifted ground truth

..., EOS, TimeShift, StringNumber, Pitch, Duration, ...

Here’s what I do in the training forward pass:

                return self.model(
                    masked_input_sequence,
                    labels=labels,
                    decoder_input_ids=self.shift_tokens_right(input_sequence),
                    encoder_outputs=encoder_outputs,
                )

I’ve got 27K guitar tabs from a large dataset. In the training, I’m seeing convergence to 98% accuracy on the validation set. See below for plots.

However, the model seems to do significantly worse when I am trying to have it do autoregressive inference without ground truth decoder input IDs. I’ve heard this issue described as “exposure bias”.

At inference, this is what I do:

            # Autoregressive generation
            encoder_last_hidden_state = None
            for i in range(0, masked_input_sequence.shape[-1], TOKENS_PER_NOTE):
                # Get the model's prediction for the current sequence
                with torch.no_grad():
                    outputs = self(
                        masked_input_sequence,
                        encoder_last_hidden_state=encoder_last_hidden_state,
                    )

                encoder_last_hidden_state = outputs.encoder_last_hidden_state

                # Extract the prediction for the masked Program token
                program_token_pred = outputs.logits[:, i + 1, :].argmax(dim=-1)
                string_predictions[:, i // TOKENS_PER_NOTE] = program_token_pred

                # Update the input tensor with the predicted Program token
                masked_input_sequence[:, i + 1] = program_token_pred

With this, the accuracy drops from 98% down to 85%. I suppose this could be explained simply as compounding errors, but the network seems to make a lot of “dumb mistakes”, like assigning the same string to co-occurring notes (see below).

In summary, the results from auto-regressive inference are significantly worse than the accuracy would lead me to believe, and I’m hoping the cause is something stupid I am doing. I’m thinking perhaps a beam search might help.

If nothing seems terribly wrong with my technical approach, I think I will have to consider representing the problem differently / going back to the drawing board.

Thanks!