Pruning a model embedding matrix for memory efficiency

Hi, I’m trying to finetune the facebook/mbart-large-50-many-to-many-mmt model for machine translation. Unfortunately, I keep maxing out my GPU memory, and even with a batch size of 1 and gradient accumulation I can’t get it to work.

I was looking through potential solutions and came across this thread, where pruning the embeddings was suggested as a fix. @sshleifer created issues for this here and here, but I don’t think they saw any progress.

I’m trying to do this myself right now, and was wondering if my approach is correct (a rough sketch follows the list) -

  1. Run the tokenizer on the dataset and collect the vocabulary of all unique tokens
  2. Copy the embeddings associated with that vocabulary into a new, smaller embedding matrix
  3. Replace the embedding matrix in the model with the new one
  4. Map the old vocabulary indices to their corresponding indices in the new embedding matrix
  5. Run the tokenizer again, but remap its token ids to the new embedding matrix’s indices before passing them to the model
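
In code, steps 1–4 might look roughly like this (dataset_texts, keep_ids, and old2new are placeholder names; the model’s special tokens need to be kept as well):

import torch
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# Step 1: collect every unique token id that appears in the data, plus the special tokens.
dataset_texts = ["..."]  # placeholder for the real training sentences
keep_ids = set(tokenizer.all_special_ids)
for text in dataset_texts:
    keep_ids.update(tokenizer(text)["input_ids"])
keep_ids = sorted(keep_ids)

# Step 2: copy the corresponding rows into a smaller embedding matrix.
old_weight = model.get_input_embeddings().weight.detach()
new_weight = old_weight[torch.tensor(keep_ids)].clone()

# Step 3: replace the model's embedding matrix with the reduced one.
# Note: the lm head has to be reduced the same way for generation to work.
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(new_weight, freeze=False))

# Step 4: map old ids to their row indices in the new matrix, to remap
# tokenizer output before it reaches the model.
old2new = {old_id: new_id for new_id, old_id in enumerate(keep_ids)}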

Does anyone here have any idea if this could work?

Yes, this seems like the right approach.
When you get to steps 4/5 you can just make a new tokenizer.
If you get it working please post the solution here!

1 Like

So, I had success in getting this to work! I was able to prune the embedding matrix and lm heads to less than a tenth of their original sizes. Testing a couple of Hindi-to-English samples, I saw no difference between the stock model’s translations and the pruned model’s translations.

Btw, I’m stuck at step 4, where I need to make a new tokenizer for the vocabulary (subword-token-to-index mapping) that I’ve generated. I’m trying to use MBart50TokenizerFast for this, and am currently using dictionaries to map old indices to new indices. I’d really appreciate it if you could point me in the right direction.

1 Like

I mentioned before that I got it to work, but it seems that while inference works perfectly, training doesn’t. When I try to train the model I get complete garbage results and can’t really tell why. The training and validation losses are extremely small (around 2e-3), but the ROUGE scores I’m calculating are also abysmal (approx. 2e-4). Lastly, training the model for one epoch makes it completely forget how to translate between the two languages I pruned it for.

Okay, so I’ve worked everything out except the tokenizer. The model can be pruned and trained to perform quite well. Like I said above, I was getting extremely bad results, but it turns out that was because my learning rate of 1e-5 was too high. I finally settled on a learning rate of 1e-8, and the model now actually converges. I feel that adding an LR scheduler with warmup, like the fairseq models use, would resolve this, but I’m not sure how to do that with the Seq2SeqTrainer yet.
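
In case it helps, I think warmup can be configured through the training arguments passed to the Seq2SeqTrainer; here is a rough sketch under that assumption (the scheduler type, warmup steps, batch sizes, and the model/dataset variables are placeholders, not values I’ve verified):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="pruned-mbart",          # placeholder path
    learning_rate=1e-8,
    lr_scheduler_type="polynomial",     # fairseq-style decay; "linear" is another option
    warmup_steps=500,                   # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,                        # the pruned model
    args=training_args,
    train_dataset=train_dataset,        # assumed to be prepared elsewhere
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()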

I still don’t know how to create a new tokenizer, but for the time being I’ve defined a custom tokenizer that inherits from the main MBart50TokenizerFast class and adds three functions - one to add the mapping from the old dictionary to the pruned dictionary, and two to encode and decode using this new dictionary. This may not be the “correct” way (which would be producing a new sentencepiece model), but it works well enough in my opinion. I’m still trying to figure that part out, but haven’t been able to yet.
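
A rough sketch of that wrapper (the method names and the old2new mapping are my own placeholders, not part of the transformers API):

from transformers import MBart50TokenizerFast

class PrunedMBart50Tokenizer(MBart50TokenizerFast):
    def add_pruned_vocab(self, old2new):
        # Store the old-id -> new-id mapping produced while pruning.
        self.old2new = old2new
        self.new2old = {v: k for k, v in old2new.items()}

    def encode_pruned(self, text, **kwargs):
        # Tokenize as usual, then remap ids into the pruned vocabulary.
        return [self.old2new[i] for i in super().encode(text, **kwargs)]

    def decode_pruned(self, ids, **kwargs):
        # Map pruned ids back to the original ids, then decode normally.
        return super().decode([self.new2old[i] for i in ids], **kwargs)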

I would like to upload the pruned and finetuned model to the Model hub, but I’m unsure how that can be done without making a new sentencepiece model.

Hi Aditya Srivastava,

Could you share your code for pruning the embedding matrix and lm heads?

The weights of the input embedding and lm head seem to be shared. I don’t know the correct way to change the weights while keeping this constraint.

import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
old_embedding = model.get_input_embeddings()

# Select the embeddings for some subset of tokens (here simply the first 1000
# rows as a stand-in) and build a smaller embedding layer from them.
new_embedding = torch.nn.Embedding.from_pretrained(old_embedding.weight.detach()[:1000].clone())
model.set_input_embeddings(new_embedding)

# The lm head still has the full vocabulary size:
print(model.lm_head.state_dict()["weight"].shape)
# Expect: [1000, 768]  Actual: [250112, 768]
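
One direction I’ve been considering is to reduce the lm head with the same row selection and then re-tie the weights, roughly like the sketch below (keep_ids is a placeholder for the ids of the tokens to retain; I haven’t verified that this keeps the sharing intact):

import torch
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")  # start from a fresh model

keep_ids = torch.tensor([0, 1, 2, 250099])  # placeholder: ids of the tokens to retain

# Reduce the input embeddings to the selected rows.
old_weight = model.get_input_embeddings().weight.detach()
new_weight = old_weight[keep_ids].clone()
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(new_weight, freeze=False))

# Reduce the lm head with the same rows so its shape matches the new embeddings.
new_lm_head = torch.nn.Linear(new_weight.shape[1], new_weight.shape[0], bias=False)
new_lm_head.weight.data = model.lm_head.weight.detach()[keep_ids].clone()
model.set_output_embeddings(new_lm_head)

# Keep the config consistent and, if the model ties embeddings and lm head, re-tie them.
model.config.vocab_size = new_weight.shape[0]
if model.config.tie_word_embeddings:
    model.tie_weights()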

1 Like

@sshleifer Hi, it’s been a while. I actually managed to get everything working correctly, including the tokenizer. Seeing how many hits this post has gotten and how many people have reached out to me since, I recently converted my code into a Python library, which is now hosted on PyPI and supports both BART and T5.

Link to package. You can use the library to trim a model and its tokenizer to your data and then save both as new models. These models can then be reloaded like native HuggingFace models for later use.

@Bookworm hope this helps you too.

1 Like

That’s pretty cool! Looking forward to benchmark results, @IamAdiSri!

1 Like