Pruning a model embedding matrix for memory efficiency

Hi Aditya Srivastava,

Could you share your code for pruning the embedding matrix and the lm head?

The weights of the input embedding and the lm head seem to be shared, and I don't know the correct way to change them while keeping this constraint.

import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
old_embedding = model.get_input_embeddings()
# ...select embeddings for some tokens (a random 1000-row matrix as a stand-in here)
new_embedding = torch.nn.Embedding.from_pretrained(torch.rand(1000, 768))
model.set_input_embeddings(new_embedding)

print(model.lm_head.weight.shape)
# Expect: [1000, 768]  Actual: [250112, 768]
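
For reference, here is a minimal sketch of the direction I've been trying, assuming the lm head just needs the same rows selected and the tying re-applied afterwards via model.tie_weights(). Note that keep_ids is a placeholder for the token ids I actually want to keep:

import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")

# Placeholder: ids of the tokens to keep (first 1000 used for illustration).
keep_ids = torch.arange(1000)

# Prune the input embedding by copying only the kept rows.
old_embedding = model.get_input_embeddings()
new_embedding = torch.nn.Embedding.from_pretrained(
    old_embedding.weight[keep_ids].clone(), freeze=False
)
model.set_input_embeddings(new_embedding)

# set_input_embeddings does not touch the lm head, so prune it the same way.
old_lm_head = model.get_output_embeddings()
new_lm_head = torch.nn.Linear(old_lm_head.in_features, len(keep_ids), bias=False)
new_lm_head.weight.data = old_lm_head.weight.data[keep_ids].clone()
model.set_output_embeddings(new_lm_head)

# Keep the config consistent and re-apply tying if the config asks for it.
model.config.vocab_size = len(keep_ids)
model.tie_weights()

print(model.lm_head.weight.shape)  # torch.Size([1000, 768])

Is this the right approach, or does tie_weights() end up overwriting the pruned head?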