Use one-hot encoding as input for T5 and GPT


Is it possible to train the T5 model using one-hot encoded inputs and integer token ids as targets?

Something like: `loss = model(onehot, attn, targetInList, attn)`

Since it's a translation problem, the one-hot input would look like `[[0,0,0,1,…],[0,1,0,0,…],…]`.

Any help would be appreciated!

If that's not possible, is there another way to convert the one-hot input back to a normal list of integer ids without breaking the computation graph? torch.argmax is not differentiable.
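To make the question concrete, here is a minimal sketch of the kind of setup I mean (a tiny randomly initialised T5 with placeholder sizes, not my real model). Matrix-multiplying the one-hot rows by the embedding table and passing the result as `inputs_embeds` looks like one differentiable alternative to argmax, but I'm not sure it's the right approach:

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

# Tiny randomly initialised T5 -- sizes are placeholders for illustration
config = T5Config(vocab_size=32, d_model=16, d_ff=32, num_layers=2,
                  num_heads=2, d_kv=8, decoder_start_token_id=0)
model = T5ForConditionalGeneration(config)

batch, seq_len = 2, 5
ids = torch.randint(0, config.vocab_size, (batch, seq_len))

# One-hot version of the same ids: shape (batch, seq_len, vocab_size)
onehot = torch.nn.functional.one_hot(ids, config.vocab_size).float()

# Multiplying the one-hot matrix by the embedding table is differentiable
# (unlike torch.argmax) and yields exactly the embeddings the integer ids
# would have produced.
inputs_embeds = onehot @ model.get_input_embeddings().weight

labels = torch.randint(0, config.vocab_size, (batch, seq_len))
out = model(inputs_embeds=inputs_embeds, labels=labels)
out.loss.backward()  # gradients flow back through the one-hot path
```

Is passing `inputs_embeds` like this the intended way to feed non-integer inputs to T5?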