I am trying to use the decoder of the BART transformer in non-causal mode. This is straightforward in JAX: just change the self.causal flag to False. The PyTorch implementation also has an is_decoder flag, but that flag doesn't control whether attention is causal or non-causal. Is there a way to use the decoder in non-causal mode with PyTorch?
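
To be concrete about what I mean by causal versus non-causal, here is a minimal sketch in plain Python. The attention_mask helper is a hypothetical name used only for illustration; it is not part of the transformers library and does not reflect how either implementation builds its masks internally.

```python
def attention_mask(seq_len, causal):
    """Build a seq_len x seq_len self-attention mask.

    Entry [i][j] is 1 if query position i may attend to key position j,
    and 0 otherwise. (Hypothetical helper, for illustration only.)
    """
    return [
        # Causal: position i sees only positions j <= i.
        # Non-causal: every position sees every other position.
        [1 if (not causal or j <= i) else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


causal_mask = attention_mask(3, causal=True)   # lower-triangular
full_mask = attention_mask(3, causal=False)    # all ones (what I want)

for row in causal_mask:
    print(row)
```

In other words, I would like the decoder's self-attention to behave like full_mask rather than causal_mask.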