PerceiverModel training: logits do not require grad and have no grad_fn

I have a Perceiver model from Hugging Face, defined as below:

from transformers import PerceiverConfig, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverClassificationDecoder,
)

  ...
  config = PerceiverConfig(d_model=self._token_size, num_labels=self._num_labels)
  decoder = PerceiverClassificationDecoder(
      config,
      num_channels=config.d_latents,
      trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
      use_query_residual=True,
  )
  return PerceiverModel(config, decoder=decoder)

token_size = 800
num_labels = 7

As input I pass a tensor of shape [batch_size, 32, 800], and the labels are a tensor of shape [batch_size, 7].
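
To make the setup easier to reproduce, here is a self-contained sketch of what I build and call (the random batch and batch_size = 4 are just for illustration):

import torch
from transformers import PerceiverConfig, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverClassificationDecoder,
)

token_size = 800
num_labels = 7
batch_size = 4  # illustrative only

config = PerceiverConfig(d_model=token_size, num_labels=num_labels)
decoder = PerceiverClassificationDecoder(
    config,
    num_channels=config.d_latents,
    trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
    use_query_residual=True,
)
model = PerceiverModel(config, decoder=decoder)

inputs = torch.randn(batch_size, 32, token_size)  # [batch_size, 32, 800]
labels = torch.randn(batch_size, num_labels)      # [batch_size, 7]

outputs = model(inputs=inputs)
print(outputs.logits.shape)          # expected: torch.Size([4, 7])
print(outputs.logits.requires_grad)  # this is the attribute that comes back False for me in training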

My training loop is as follows:

import torch
from torch import optim
from tqdm import tqdm

criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=envi_builder.config.learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=envi_builder.config.step_size, gamma=0.5)

model.train()
for epoch in range(envi_builder.config.n_epochs):
    loop = tqdm(dataloader_train, leave=True)
    for (inputs, labels) in loop:
        optimizer.zero_grad()

        inputs = inputs.to(envi_builder.config.device)
        labels = labels.to(envi_builder.config.device)

        outputs = model(inputs=inputs)
        logits = outputs.logits

        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

For the loss/logits attributes I see:

grad_fn = None
requires_grad = False
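
This is roughly how I check it inside the loop (a small print-only diagnostic sketch; the extra sanity checks at the bottom are just conditions I would expect to hold during training):

# Quick diagnostic inside the training loop (print-only, no behaviour change).
print(logits.requires_grad, logits.grad_fn)   # -> False, None
print(loss.requires_grad, loss.grad_fn)       # -> False, None

# Sanity checks that should normally hold while training:
print(torch.is_grad_enabled())                           # expect True (no torch.no_grad() around this)
print(all(p.requires_grad for p in model.parameters()))  # expect True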

It seems that something wrong is going on inside the model's forward() in this part:

        sequence_output = encoder_outputs[0]

        logits = None
        if self.decoder:
            if subsampled_output_points is not None:
                output_modality_sizes = {
                    "audio": subsampled_output_points["audio"].shape[0],
                    "image": subsampled_output_points["image"].shape[0],
                    "label": 1,
                }
            else:
                output_modality_sizes = modality_sizes
            decoder_query = self.decoder.decoder_query(
                inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
            )
            decoder_outputs = self.decoder(
                decoder_query,
                z=sequence_output,
                query_mask=extended_attention_mask,
                output_attentions=output_attentions,
            )
            logits = decoder_outputs.logits  # ---> no grad, no grad_fn
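
To narrow down whether the autograd graph is already gone before the decoder runs, I can also inspect the encoder output from outside the model (a small diagnostic sketch, assuming the same model and inputs as above):

# If last_hidden_state already has no grad_fn, the problem is upstream of the decoder shown above.
outputs = model(inputs=inputs)
hidden = outputs.last_hidden_state
print(hidden.requires_grad, hidden.grad_fn)
print(outputs.logits.requires_grad, outputs.logits.grad_fn)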

Any idea?