Token-level representations


I'd like to get token-level representations as the output of my encoder.

I know how to do it with an MLM-based encoder, but for my specific use case I only want token-level embeddings, i.e. the hidden states, and not the MLM logits anymore.

Is there any dedicated module to perform this?

Any help is welcome, thank you.

from transformers import AutoModelForMaskedLM, AutoTokenizer

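tokenizer = AutoTokenizer.from_pretrained(...)
model = AutoModelForMaskedLM.from_pretrained(...)

encoded_input = tokenizer(..., return_tensors="pt")
# hidden_states is only returned when output_hidden_states=True
output = model(**encoded_input, output_hidden_states=True)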

# Last-layer hidden states, shape (batch, seq_len, hidden_size)
output.hidden_states[-1] # I need this

# MLM logits
output.logits # I don't need this anymore and I want to avoid this computation to save time.
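A minimal sketch, assuming a BERT-style checkpoint: `AutoModel` builds the bare encoder without the MLM head and returns `last_hidden_state` directly, and an MLM model that is already loaded exposes its inner encoder through `base_model`. The tiny `BertConfig` values below are made-up placeholders so the example runs offline with random weights; they are not from any real checkpoint.

```python
import torch
from transformers import AutoModel, AutoModelForMaskedLM, BertConfig

# Hypothetical tiny config so the sketch runs offline with random weights.
# With a real checkpoint you would use AutoModel.from_pretrained(...) instead.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)

input_ids = torch.tensor([[1, 5, 7, 9, 2]])
attention_mask = torch.ones_like(input_ids)

# Option 1: AutoModel builds the bare encoder (BertModel here) with no MLM
# head, so no vocabulary-sized logits are computed.
encoder = AutoModel.from_config(config).eval()
with torch.no_grad():
    token_embeddings = encoder(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state
print(token_embeddings.shape)  # (batch, seq_len, hidden_size)

# Option 2: if an MLM model is already loaded, call its inner encoder
# through `base_model`, which skips the prediction head.
mlm_model = AutoModelForMaskedLM.from_config(config).eval()
with torch.no_grad():
    base_states = mlm_model.base_model(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state
    # Full MLM forward pass, only for comparison (this one still computes logits).
    full = mlm_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )

# The encoder output should match hidden_states[-1] from the full MLM pass.
print(torch.allclose(base_states, full.hidden_states[-1]))
```

With real weights, swap `from_config(config)` for `from_pretrained(...)` with your checkpoint. `base_model` resolves through each architecture's `base_model_prefix`, so the same call should work for other encoder families, though it is worth checking for your specific model.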