How to instantiate the BART decoder in a non-causal way - PyTorch

I am trying to use the decoder of the BART transformer in non-causal mode. In JAX (Flax) this is straightforward: you just set the `self.causal` flag of the decoder's self-attention to `False`. The PyTorch implementation also has an `is_decoder` flag, but that flag doesn't control whether the decoder runs in causal or non-causal mode. Is there a way to use the decoder in non-causal mode with PyTorch?
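A causal decoder and a non-causal decoder differ only in the self-attention mask. Under a causal mask, position i can attend only to positions j ≤ i. A non-causal decoder attends to the whole target sequence, and cross-attention to the encoder works the same way in both modes. In the Hugging Face PyTorch code, `is_decoder` on the attention layers appears to govern key/value caching rather than masking, and the causal mask appears to be built inside `BartDecoder.forward`, so that step is what a non-causal variant would have to replace. The dependency-free Python sketch below illustrates only the masking itself. The function name `self_attention` is hypothetical, and the sketch assumes a single head with no learned projections, using the raw inputs as queries, keys and values.

```python
import math
from typing import List

Vector = List[float]


def self_attention(x: List[Vector], causal: bool) -> List[Vector]:
    """Hypothetical single-head scaled dot-product self-attention.

    Simplifying assumption: no learned projections, so queries, keys and
    values are all the raw input vectors. If causal is True, position i may
    only attend to positions j <= i; otherwise it attends to every position.
    """
    n, d = len(x), len(x[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(n):
        # Raw attention scores of query i against every key j.
        scores = [sum(a * b for a, b in zip(x[i], x[j])) * scale for j in range(n)]
        if causal:
            # Causal mask: future positions get -inf, i.e. zero weight after softmax.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        # Numerically stable softmax over the (possibly masked) scores.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output i is the attention-weighted sum of the value vectors.
        out.append([sum(w * x[j][k] for j, w in enumerate(weights)) for k in range(d)])
    return out


# Usage: three toy 2-d "token" vectors.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
causal_out = self_attention(tokens, causal=True)
full_out = self_attention(tokens, causal=False)
print(causal_out[0])  # [1.0, 0.0]: under the causal mask the first position only sees itself
print(full_out[0])    # mixes in later positions when the mask is removed
```

The last position produces the same output in both modes, because under a causal mask it already sees the whole sequence. Only the earlier positions change when the mask is dropped.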