mBART embedding matrix pruning

I hope this post is okay in the forums.

I am using an mBART model and I want to reduce the size of the vocabulary. To do so:
(1) I mapped the old vocabulary to the new one.
(2) I got the original embeddings and created the new embedding matrix following my new vocabulary mapping.
(3) I put the new embedding matrix into the mBART model.
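Roughly, steps (1) and (2) look like this (a toy sketch with illustrative names and sizes, not my actual script):

```python
import torch

# Toy stand-ins for the real model: an "original" embedding matrix
# and the list of old-vocabulary ids kept in the new vocabulary.
old_embeddings = torch.randn(10, 4)   # (old_vocab_size, hidden_size)
kept_old_ids = [0, 3, 5, 7]           # old ids retained, in new-id order

# Row i of the pruned matrix is the old row kept_old_ids[i].
new_embeddings = old_embeddings[torch.tensor(kept_old_ids)]
print(new_embeddings.shape)           # torch.Size([4, 4])
```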

Specifically, for (3) I did:

import torch
from torch.nn import Embedding

# ft_tensor contains the pruned input embedding matrix
ft_tensor = torch.load(args.pruned_model + '/input_embedding_weights.pt')
new_wte = Embedding.from_pretrained(ft_tensor)
new_wte.padding_idx = 0
model.set_input_embeddings(new_wte)
# Tie the LM head to the new input embeddings
model.lm_head.weight = model.get_input_embeddings().weight
# The procedure is repeated for the output embeddings
ft_tensor = torch.load(args.pruned_model + '/output_embedding_weights.pt')
new_wte = Embedding.from_pretrained(ft_tensor)
new_wte.padding_idx = 0
model.set_output_embeddings(new_wte)

I also adapted the Tokenizer to the new vocabulary. When running this (in the forward pass) I get the following error:

Traceback (most recent call last):
File "main.py", line 71, in
main()
File "main.py", line 49, in main
outputs = model(**batch)
File "/var/python3envs/transformers-4.5.1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/var/python3envs/transformers-4.5.1/lib/python3.6/site-packages/transformers/models/mbart/modeling_mbart.py", line 1305, in forward
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
File "/var/python3envs/transformers-4.5.1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/var/python3envs/transformers-4.5.1/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 126, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/var/python3envs/transformers-4.5.1/lib/python3.6/site-packages/torch/nn/functional.py", line 1852, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)

The first thing I tried was casting the input to Long, but that did not resolve the error. Notably, if I do not modify the model's embeddings at all (i.e., keep the original model), the code runs fine. However, if I take the original embeddings (model.get_input_embeddings().weight) and set them back again (model.set_input_embeddings(…)), the same error appears. I hope someone can shed light on this issue.
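For what it's worth, I can reproduce the same RuntimeError in isolation by applying an nn.Embedding to float hidden states, which is the call pattern the traceback shows for self.lm_head(outputs[0]) (a toy sketch with made-up sizes, not my actual model):

```python
import torch
import torch.nn as nn

hidden = torch.randn(2, 5, 8)        # float hidden states, as the LM head receives

# A Linear head (as in the stock model) accepts float inputs.
linear_head = nn.Linear(8, 100, bias=False)
print(linear_head(hidden).shape)     # torch.Size([2, 5, 100])

# An Embedding in the same position expects Long token indices instead.
embedding_head = nn.Embedding(100, 8)
try:
    embedding_head(hidden)
except RuntimeError as e:
    print(e)                         # complains about 'indices' not being Long
```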