Are there any suggestions for this? I am not sure if hidden[:, 0, :] makes sense (since there is no [CLS] token in T5), but I found that using hidden[:, 0, :] yields better results than torch.mean(hidden_states, dim=1). Any suggestions on the best way to do this with T5Encoder?
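One possible factor worth checking: if the batch is padded, torch.mean(hidden_states, dim=1) also averages the padding positions, which can pull the embedding away from the real tokens. A common alternative is a mean that uses the attention mask so that only real tokens are averaged. Below is a minimal sketch in plain PyTorch. The function names first_token_pool and masked_mean_pool are hypothetical helpers for illustration, and it assumes hidden has shape (batch, seq_len, d_model) and attention_mask uses 1 for real tokens and 0 for padding. Which pooling works better is task-dependent and should be checked on your own data.

```python
import torch


def first_token_pool(hidden: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hidden state at position 0 for each sequence.

    hidden: (batch, seq_len, d_model). T5 has no [CLS] token, so position 0
    is simply the first real token of the input.
    """
    return hidden[:, 0, :]


def masked_mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average hidden states over non-padding positions only.

    hidden:         (batch, seq_len, d_model)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                   # sum of real-token states
    counts = mask.sum(dim=1).clamp(min=1e-9)              # real-token count, avoids /0
    return summed / counts                                # (batch, d_model)


if __name__ == "__main__":
    # Toy batch: 2 sequences, seq_len 3, d_model 2; the second has one pad token.
    hidden = torch.tensor(
        [[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
         [[2.0, 0.0], [4.0, 0.0], [100.0, 100.0]]]
    )
    attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
    print(masked_mean_pool(hidden, attention_mask))  # rows: [3, 3] and [3, 0]
    print(torch.mean(hidden, dim=1))                 # second row is skewed by the pad
```

With Hugging Face transformers, hidden would typically be outputs.last_hidden_state from a T5EncoderModel forward pass, and attention_mask the same mask returned by the tokenizer and passed to the model.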