Hi Aditya Srivastava,
Could you share your code for pruning the embedding matrix and the lm_head?
The weights of the input embedding and the lm_head seem to be shared, and I don't know the correct way to change the weights while keeping that constraint. Here is what I tried:
import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")

old_embedding = model.get_input_embeddings()
# ...select embeddings for some tokens (torch.rand is just a stand-in here)
new_embedding = torch.nn.Embedding.from_pretrained(torch.rand(1000, 768))
model.set_input_embeddings(new_embedding)

# The lm_head still has the full mT5 vocabulary size:
print(model.lm_head.weight.shape)
# Expected: [1000, 768]  Actual: [250112, 768]
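For context, this is the kind of pruning I have in mind. It is only a minimal sketch: keep_ids is a placeholder for the token ids I actually want to retain, and pruning the lm_head by hand via set_output_embeddings is my guess, not something I found documented. Is this the right way to respect the sharing?

import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")

# Hypothetical: the ids of the tokens to keep.
keep_ids = torch.arange(1000)

# Prune the shared input embedding; set_input_embeddings updates the
# embedding used by both the encoder and the decoder.
old_weight = model.get_input_embeddings().weight.detach()
new_embedding = torch.nn.Embedding.from_pretrained(
    old_weight[keep_ids].clone(), freeze=False
)
model.set_input_embeddings(new_embedding)

# Prune the lm_head explicitly, since set_input_embeddings alone
# does not touch it (as the print above shows).
new_lm_head = torch.nn.Linear(old_weight.shape[1], len(keep_ids), bias=False)
new_lm_head.weight.data = model.lm_head.weight.data[keep_ids].clone()
model.set_output_embeddings(new_lm_head)

# Keep the config consistent; tie_weights() re-ties input and output
# embeddings only when config.tie_word_embeddings is True.
model.config.vocab_size = len(keep_ids)
model.tie_weights()

print(model.get_input_embeddings().weight.shape)  # torch.Size([1000, 768])
print(model.lm_head.weight.shape)                 # torch.Size([1000, 768])

One thing I am aware of: after pruning, the kept rows are renumbered 0..999, so the tokenizer's ids would have to be remapped to match, and keep_ids would need to include the special tokens (pad, eos, etc.).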