I have a Perceiver model from Hugging Face, defined as below:
from transformers import PerceiverConfig, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverClassificationDecoder,
)

...
config = PerceiverConfig(d_model=self._token_size, num_labels=self._num_labels)
decoder = PerceiverClassificationDecoder(
    config,
    num_channels=config.d_latents,
    trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
    use_query_residual=True,
)
return PerceiverModel(config, decoder=decoder)
where token_size = 800 and num_labels = 7 (the values bound to self._token_size and self._num_labels).
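A freshly constructed Hugging Face model should have every parameter trainable, so this sanity check (a sketch, where build_model() is a hypothetical wrapper around the snippet above) should pass:

model = build_model()  # hypothetical helper returning the PerceiverModel above
assert all(p.requires_grad for p in model.parameters())  # True for a fresh model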
As input I pass a tensor of shape [batch_size, 32, 800], and labels as a tensor of shape [batch_size, 7].
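For a minimal repro, dummy tensors with exactly those shapes can stand in for the real data (a sketch; batch_size = 4 is arbitrary):

import torch

inputs = torch.randn(4, 32, 800)  # [batch_size, 32, token_size]
labels = torch.rand(4, 7)         # [batch_size, num_labels], float targets for MSELoss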
My training loop is as follows:
import torch
from tqdm import tqdm

criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=envi_builder.config.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=envi_builder.config.step_size, gamma=0.5)

model.train()
for epoch in range(envi_builder.config.n_epochs):
    loop = tqdm(dataloader_train, leave=True)
    for inputs, labels in loop:
        optimizer.zero_grad()
        inputs = inputs.to(envi_builder.config.device)
        labels = labels.to(envi_builder.config.device)
        outputs = model(inputs=inputs)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # advance the LR schedule once per epoch
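Condensed to a single step outside the loop (a sketch reusing the dummy tensors above), the problem should already be visible right after the forward pass:

model.train()
outputs = model(inputs=inputs)  # plain forward call, no autograd-disabling context here
loss = torch.nn.functional.mse_loss(outputs.logits, labels)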
For the loss/logits tensors I see:

grad_fn = None
requires_grad = False
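A couple of checks that would narrow this down (a sketch; in a healthy setup both should print True):

print(torch.is_grad_enabled())                 # False would point to a no_grad()/inference_mode context
print(next(model.parameters()).requires_grad)  # False would mean the parameters were frozen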
It seems that something wrong is going on inside the model's forward(), in this part of transformers' modeling_perceiver.py:
sequence_output = encoder_outputs[0]

logits = None
if self.decoder:
    if subsampled_output_points is not None:
        output_modality_sizes = {
            "audio": subsampled_output_points["audio"].shape[0],
            "image": subsampled_output_points["image"].shape[0],
            "label": 1,
        }
    else:
        output_modality_sizes = modality_sizes
    decoder_query = self.decoder.decoder_query(
        inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
    )
    decoder_outputs = self.decoder(
        decoder_query,
        z=sequence_output,
        query_mask=extended_attention_mask,
        output_attentions=output_attentions,
    )
    logits = decoder_outputs.logits  # ---> no grad, no grad_fn
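Since the logits come straight from self.decoder, a check on the decoder's own parameters might also help (a sketch; PerceiverModel stores the decoder passed to its constructor as model.decoder):

print(all(p.requires_grad for p in model.decoder.parameters()))  # expected: True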
Any ideas?