Perceiver IO: Is there any way to specify the query tensor?

I need to implement something similar to the architecture in this photo. The Perceiver IO model will take:

  1. C_t
  2. Q_t (just for the decoding phase)

I have looked at the basic decoder API, but there seems to be no way to specify the query vector.

This code is from another PyTorch implementation (perceiver-pytorch):

```python
import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,                    # dimension of sequence to be encoded
    queries_dim = 32,            # dimension of decoder queries
    logits_dim = 100,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
    seq_dropout_prob = 0.2       # fraction of the tokens from the input sequence to dropout (structured dropout, for saving compute and regularizing effects)
)

seq = torch.randn(1, 512, 32)     # C_t: (batch, seq len, dim)
queries = torch.randn(128, 32)    # Q_t: (decoder seq, queries_dim)

logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)
```

Here I can pass these two inputs to Perceiver IO in a straightforward manner. Is it possible to do the same with Hugging Face's Perceiver IO API?


You can specify the query tensor if you prefer, by implementing a decoder yourself. See here for how a decoder should be implemented. In particular, you can implement the decoder_query method, which defines the queries that are fed to the decoder.

PerceiverModel in the Transformers library accepts a decoder argument, so you can then create the model as follows:

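```python
from transformers import PerceiverConfig, PerceiverModel

config = PerceiverConfig()
# your_decoder: an instance of your custom decoder class
model = PerceiverModel(config, decoder=your_decoder)
```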
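A minimal end-to-end sketch of that approach, mirroring the perceiver-pytorch example from the question, might look like this. It assumes a recent transformers version in which `PerceiverBasicDecoder` can be imported from `transformers.models.perceiver.modeling_perceiver`. The class name `ExternalQueryDecoder`, its `set_queries` method and the `external_queries` attribute are hypothetical, and because `PerceiverModel.forward` has no argument for queries, stashing Q_t on the decoder before each forward pass is an assumption of this sketch rather than an official API. The config values are arbitrary small numbers chosen so the example runs quickly.

```python
import torch
from transformers import PerceiverConfig, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import PerceiverBasicDecoder


class ExternalQueryDecoder(PerceiverBasicDecoder):
    """Hypothetical decoder whose queries (Q_t) are supplied by the caller."""

    def __init__(self, config, query_dim, output_num_channels):
        super().__init__(
            config,
            output_num_channels=output_num_channels,  # size of the final logits
            position_encoding_type="none",  # do not build position-encoding queries
            num_channels=query_dim,  # feature size of each query vector (q_dim)
        )
        self.query_dim = query_dim
        self.external_queries = None  # hypothetical slot for Q_t

    def set_queries(self, queries):
        # queries: (num_queries, query_dim) or (batch, num_queries, query_dim)
        self.external_queries = queries

    @property
    def num_query_channels(self):
        return self.query_dim

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        # PerceiverModel.forward calls this; return the stored queries instead of building them.
        if self.external_queries is None:
            raise ValueError("Call set_queries() before running the model.")
        queries = self.external_queries
        if queries.dim() == 2:
            # Share one set of queries across every example in the batch.
            queries = queries.unsqueeze(0).expand(inputs.shape[0], -1, -1)
        return queries


# Small config so the sketch runs fast; d_model must equal the last dim of C_t.
config = PerceiverConfig(
    d_model=32,
    num_latents=64,
    d_latents=128,
    num_blocks=1,
    num_self_attends_per_block=2,
    num_self_attention_heads=8,
    num_cross_attention_heads=1,
)
decoder = ExternalQueryDecoder(config, query_dim=32, output_num_channels=100)
model = PerceiverModel(config, decoder=decoder)
model.eval()

seq = torch.randn(1, 512, 32)   # C_t: (batch, seq_len, d_model)
queries = torch.randn(128, 32)  # Q_t: (num_queries, query_dim)

decoder.set_queries(queries)
with torch.no_grad():
    outputs = model(inputs=seq)
logits = outputs.logits  # (batch, num_queries, output_num_channels)
```

Setting `position_encoding_type="none"` stops the basic decoder from constructing its own positional queries, so whatever `decoder_query` returns is used directly; `num_channels` has to match the last dimension of Q_t. Keeping the queries on the decoder makes the model stateful, so `set_queries` must be called before every forward pass. A stateless alternative would be to build `PerceiverModel(config)` without a decoder and call the decoder yourself on the latents, e.g. `decoder(batched_queries, z=outputs.last_hidden_state).logits`, where `batched_queries` already has shape `(batch, num_queries, query_dim)`.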