Recovering token ids from normalized input?

Trying to figure out conceptually what is wrong here. I have a flow that does the following:

Text → Produce Token Ids → Normalize Ids → AutoEncoder → Calculate CosineEmbeddingLoss.

This pipeline seems to work and the model trains, but I cannot reproduce any of the inputs: the token ids have been normalized, so tokenizer.decode() no longer works on the reconstructions. Is there a better way to do this?
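One way to keep decode() usable is to make the normalization explicitly invertible, i.e. keep the stats used to scale the ids so the transform can be undone on the model's output. A minimal sketch of that idea (the helper names and vocab size here are my own assumptions, not from the original code):

```python
import torch

def normalize_ids(ids, vocab_size):
    # Map integer ids in [0, vocab_size) to floats in [0, 1].
    return ids.float() / (vocab_size - 1)

def denormalize_ids(x, vocab_size):
    # Undo the scaling and round back to integer ids.
    return (x * (vocab_size - 1)).round().long()

ids = torch.tensor([[101, 2054, 2003, 102]])
x = normalize_ids(ids, vocab_size=30522)
recovered = denormalize_ids(x, vocab_size=30522)
# recovered is integer-valued again and can be passed to tokenizer.decode()
```

The same denormalize-then-round step would be applied to the autoencoder's output before decoding.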

Relevant code:

class AE(nn.Module):
  def __init__(self):
    super().__init__()  # required so nn.Module can register the submodules
    self.encoder = torch.nn.Sequential(
      torch.nn.Linear(512, 512), # Input is in the format (Batch x 512)
      torch.nn.Linear(512, 256),
    )
    self.decoder = torch.nn.Sequential(
      torch.nn.Linear(256, 512),
      torch.nn.Linear(512, 512),
    )

  def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)
    return x
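As an aside, stacked Linear layers with no activation between them collapse into a single linear map, so the depth buys nothing. A sketch of the same architecture with nonlinearities added (the class name and choice of ReLU are assumptions, not from the original):

```python
import torch
import torch.nn as nn

class AEWithActivations(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),               # nonlinearity between layers
            nn.Linear(512, 256),
        )
        self.decoder = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = AEWithActivations()
out = model(torch.randn(8, 512))  # shape (8, 512), same as the input
```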

And the training step:

  def training_step(self, batch, batch_idx):
    x = batch
    x_hat = self(x)  # reconstruct through the autoencoder
    loss_fn = nn.CosineEmbeddingLoss()
    # target must be shape (batch,); 1 means "x_hat should match x"
    loss = loss_fn(x_hat, x, torch.ones(x.size(0), device=x.device))
    return loss

I was thinking of applying F.normalize in the encoder, but again I am not sure how to undo that transform with the decoder, or how I would emit outputs. Or do I need to swap nn.Sigmoid for nn.ReLU? (Cosine similarity is scale-invariant, so I'm not sure whether changing the output scaling would even affect my loss.)
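On the scaling question: CosineEmbeddingLoss compares directions only, so rescaling an input leaves the loss unchanged. A quick sketch to confirm (the tensor shapes here are just an example):

```python
import torch
import torch.nn as nn

loss_fn = nn.CosineEmbeddingLoss()
a = torch.randn(4, 512)
b = torch.randn(4, 512)
target = torch.ones(4)  # 1 = "these pairs should be similar"

base = loss_fn(a, b, target)
scaled = loss_fn(3.0 * a, b, target)  # same directions, 3x magnitude
# base and scaled are numerically identical: cosine ignores magnitude
```

This is exactly why the model can "complete the task" under this loss while its raw outputs drift in magnitude away from valid token ids.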