Train new Word Embedding for mBART

TL;DR: I want to train a (set of) new word embedding(s) for mBART instead of for BERT. How do I do that?

Background:

I found some interesting code here: GitHub - tai314159/PWIBM-Putting-Words-in-Bert-s-Mouth: Putting Words in Bert's Mouth: Navigating Contextualized Vector Spaces with Pseudowords. The code uses example sentences to generate so-called “pseudoword embeddings” in get_pseudowords.py. As published, it outputs one pseudoword embedding per example sentence (stored in data/queries/single_target/MaPP_all.txt).

After adjusting a few lines of the code, it worked for me and I could add the newly generated embeddings to a vanilla BERT model. I also extended the code so that it outputs only one pseudoword per “group” of example sentences:

# [...]

NEW_TOKEN = '#TOKEN#'

# [...]

class Coercion:

# [...]

    def coercion(self,
                 group,
                 k: int = 5):
        model = BertForMaskedLM.from_pretrained(
            'bert-base-cased', return_dict=True)
        model.to('cuda')

        self.builder.tokenizer.add_tokens(NEW_TOKEN)
        model.resize_token_embeddings(len(self.builder.tokenizer))

        new_queries = []
        queries = []
        vec_targets = []

        # Print the targets (and their indices) and the query (and its index)
        for entry in group:
            i = 0
            while True:
                i = i + 1
                if ('target' + str(i)) not in entry.keys():
                    break
                print('target ' + str(i) + ': ' + entry["target" + str(i)] + " , " + str(entry["target" + str(i) + "_idx"]))
            print('query:' + entry["query"] + " , " + str(entry["query_idx"]))

            # Model output
            nlp = FillMaskPipeline(model, self.builder.tokenizer, device=0)
            output = nlp(entry["query"])
            output = self._format(output)
            print('[MASK]=' + str(output))

            for j in range(1, i):
                vec_targets.append(
                    self._get_target_embed((entry["target" + str(j)], entry["target" + str(j) + "_idx"]), model))

            new_query = entry["query"].split()
            new_query[entry["query_idx"]] = NEW_TOKEN
            new_query = ' '.join(new_query)
            query = (new_query, entry["query_idx"])
            print(query)
            new_queries.append(new_query)
            queries.append(query)

        model = self._freeze(model)

        model.eval()

        for i in range(k):
            print('-' * 40)
            print('Random {a}'.format(a=i))

            # Random initialization, same initialization as huggingface
            weight = model.bert.embeddings.word_embeddings.weight.data[-1]
            nn.init.normal_(weight, mean=0.0,
                            std=model.config.initializer_range)

            # Before training
            print('Before training:')
            nlp = FillMaskPipeline(model, self.builder.tokenizer, device=0)

            model = self._train(model, vec_targets, queries)

            print("*************************************************************************")
            # After training
            print('After training:')
            nlp = FillMaskPipeline(model, self.builder.tokenizer, device=0)
            for new_query in set(new_queries):  # only view different queries
                print("query: " + new_query)
                output = nlp(new_query)
                output = self._format(output)
                print('[MASK]=' + str(output))

                outputs_list.append(output)

                output = self._predict_z(model, query)
                output = self._format(output)
                print(NEW_TOKEN + '=' + str(output))
            print("*************************************************************************")

    def _train(self, model, vec_targets, queries):
        loss_fct = nn.MSELoss(reduction='mean')  # MSE between the contextual vectors and the target vectors
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.3, eps=1e-8)
        epoch = 1000
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=epoch)

        max_length = 1 + max([len(self.builder.encode(query[0])[1]) for query in queries])  # possible padding
        input_ids_and_gather_indexes = [self.builder.encode(query[0], max_length=max_length) for query in queries]
        input_ids = torch.cat([input_id for input_id in [i for i, _ in input_ids_and_gather_indexes]], dim=0).to("cuda")
        gather_indexes = [gather_index for gather_index in [g for _, g in input_ids_and_gather_indexes]]

        # target_idx is the index of target word in the token list.
        target_idxs = [g[q[1] + 1][0] for g, q in zip(gather_indexes, queries)]
        target_idxs = torch.tensor(target_idxs, device="cuda").unsqueeze(-1)
        # token_idx is the index of target word in the vocabulary of BERT
        token_idxs = input_ids.gather(dim=-1, index=target_idxs)
        vocab_size = len(tokenizer.get_vocab())
        min_token_idx = min(token_idxs)
        indices = torch.tensor([i for i in range(vocab_size) if i < min_token_idx], device="cuda", dtype=torch.long)

        for _ in trange(epoch):
            model.zero_grad()
            outputs = model(input_ids, output_hidden_states=True)
            z = torch.index_select(outputs.hidden_states[12][0], dim=0, index=target_idxs.squeeze(-1))

            loss = loss_fct(z, torch.stack(vec_targets))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            model.bert.embeddings.word_embeddings.weight.grad[indices] = 0
            optimizer.step()
            scheduler.step()

        # get the z* for classification
        vec = model.bert.embeddings.word_embeddings(token_idxs).squeeze(1)[0]  # this is z*; [0] because all the same
        vec_array = vec.cpu().detach().numpy()
        z_list.append(vec_array)
        loss_list.append(str(loss.cpu().detach().numpy()))

        # save checkpoints
        try:
            np.save(CACHE + "temp_z_arrays.npy", np.array(z_list))
            np.save(CACHE + "temp_loss_arrays.npy", np.array(loss_list))
        except Exception:
            print("Skip saving this time...")

        s = 'Final loss={a}'.format(a=str(loss.cpu().detach().numpy()))
        print(s)

        return model

    def _freeze(self, model):
        # Freeze all the parameters except the word embeddings
        for name, param in model.named_parameters():
            param.requires_grad = False
            if name == 'bert.embeddings.word_embeddings.weight':
                param.requires_grad = True

        # Manually break the connection of decoder and embeddings.
        original_weight = model.cls.predictions.decoder.weight
        original_bias = model.cls.predictions.decoder.bias
        decoder = nn.Linear(768, len(tokenizer) - 1, bias=True)
        decoder.weight.requires_grad = False
        decoder.bias.requires_grad = False
        decoder.weight.data.copy_(original_weight.data[:-1])
        decoder.bias.data.copy_(original_bias.data[:-1])
        model.cls.predictions.decoder = decoder

        return model

# [...]

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
builder = DataBuilder(tokenizer)
co = Coercion(builder)
for group in data:
    co.coercion(group)
    print('==' * 40)

result = get_lowest_loss_arrays(z_list, loss_list)

This also worked reasonably well, so I could append my new embeddings to a fresh model.
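
For completeness, this is roughly how I attach the saved pseudoword vectors to a fresh model afterwards (a minimal sketch; the .npy file name and the #PSEUDOi# token names are placeholders, not the actual names from my script):

import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForMaskedLM.from_pretrained('bert-base-cased')

# One trained pseudoword vector per group (the result of get_lowest_loss_arrays(), saved to disk).
pseudowords = np.load('pseudowords.npy')  # placeholder file name
new_tokens = ['#PSEUDO{}#'.format(i) for i in range(len(pseudowords))]  # placeholder token names

tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

# Overwrite the randomly initialized embedding rows with the trained vectors.
with torch.no_grad():
    for token, vec in zip(new_tokens, pseudowords):
        token_id = tokenizer.convert_tokens_to_ids(token)
        model.bert.embeddings.word_embeddings.weight[token_id] = torch.tensor(vec)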

My Approach:

Now I want to do something a bit different: use the mBART-50 model (MBart and MBart-50) for the same idea, but mask phrases of variable length with <mask>. The new, temporary #TOKEN# embedding will still only “mask” one word (as in the code above).
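
Just as a reference point, plain <mask> filling with mBART-50 (before any of my changes) can be done via generate, roughly like this (a minimal sketch following the transformers docs; the forced_bos_token_id line is my assumption about how to force German output, I am not sure it is strictly needed here):

from transformers import MBart50Tokenizer, MBartForConditionalGeneration

tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="de_DE", tgt_lang="de_DE")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# mBART-50 was pretrained with text infilling, so a single <mask> can be
# replaced by a phrase of variable length in the generated sequence.
inputs = tokenizer("Aber jetzt heisst es <mask> .", return_tensors="pt")
generated = model.generate(**inputs,
                           max_length=30,
                           num_beams=5,
                           num_return_sequences=5,
                           forced_bos_token_id=tokenizer.lang_code_to_id["de_DE"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))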

mBART has a few attributes that differ from BERT's, though, so I am not sure which lines can stay the same and which ones need to change. This is my approach so far (lines I am especially unsure about are marked with “??”):

import csv
import itertools
from typing import List, Tuple, TextIO

from transformers import AutoTokenizer, MBart50Tokenizer, MBartForConditionalGeneration, Text2TextGenerationPipeline
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
import torch
import torch.nn as nn
import torch.optim
from tqdm import trange, tqdm
import jsonlines
import json
import numpy as np
import os
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
from scipy import spatial
from sklearn.metrics import mean_squared_error
import pickle

NEW_TOKEN = '#TOKEN#'

Item = Tuple[str, int]
Example = Tuple[Item, Item]

# ARGS
QUERIES_PATH = "../../out/CoMaPP_all.json"  # path to queries
DATASET_PATH = "CoMaPP_Dataset.csv"
DIR_OUT = "../../out/"  # path to dir to save the pseudowords
CACHE = "../../out/cache/"  # path to cache directory

################################################

class DataBuilder:
    def __init__(self, tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer

    def encode(self, text: str, max_length=None):
        tokens = text.split()
        # Build token indices
        _, gather_indexes = self._manual_tokenize(tokens)
        # Tokenization
        if max_length:
            encode_dict = self.tokenizer(
                text, return_attention_mask=True,
                return_token_type_ids=False, return_tensors='pt',
                padding='max_length', max_length=max_length)
        else:
            encode_dict = self.tokenizer(
                text, return_attention_mask=True,
                return_token_type_ids=False, return_tensors='pt')
        input_ids = encode_dict['input_ids']
        return input_ids, gather_indexes

    def _manual_tokenize(self, tokens: List[str]):
        split_tokens = []
        gather_indexes = []
        for token in tokens:
            indexs = []
            for sub_token in self.tokenizer.tokenize(token):
                indexs.append(len(split_tokens))
                split_tokens.append(sub_token)
            gather_indexes.append(indexs)

        gather_indexes = [(min(t), max(t) + 1) for t in gather_indexes]

        # Shift by one for the leading special token (CLS for BERT, the language code for mBART-50)
        indices = [(a + 1, b + 1) for a, b in gather_indexes]
        # Add entries for the leading and trailing special tokens (CLS/SEP or language code/</s>)
        indices = [(0, 1)] + indices + [(indices[-1][1], indices[-1][1] + 1)]
        return split_tokens, indices


class Coercion:
    def __init__(self, builder: DataBuilder):
        self.builder = builder

    def coercion(self,
                 group,
                 k: int = 5):
        model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", return_dict=True)
        model.to('cuda')

        self.builder.tokenizer.add_tokens(NEW_TOKEN)
        model.resize_token_embeddings(len(self.builder.tokenizer))

        new_queries = []
        queries = []
        vec_targets = []

        # Print the targets (and their indices) and the query (and its index)
        for entry in group:
            i = 0
            while True:
                i = i + 1
                if ('target' + str(i)) not in entry.keys():
                    break
                print(f'target {i}: {entry["target" + str(i)]}, {entry["target" + str(i) + "_idx"]}')
            print(f'query: {entry["query"]}, {entry["query_idx"]}')

            nlp = Text2TextGenerationPipeline(model=model, tokenizer=self.builder.tokenizer, device=0)
            output = nlp(entry["query"], max_length=30, num_return_sequences=5, num_beams=100)
            output = self._format(output)
            print(f"output: {output}")
            # print('<mask> = ' + str(outputs))  # TODO Just show the replaced <mask> token

            for j in range(1, i):
                vec_targets.append(
                    self._get_target_embed((entry["target" + str(j)], entry["target" + str(j) + "_idx"]), model)
                )

            new_query = entry["query"].split()
            new_query[entry["query_idx"]] = NEW_TOKEN
            new_query = ' '.join(new_query)
            query = (new_query, entry["query_idx"])
            print(query)
            new_queries.append(new_query)
            queries.append(query)

        model = self._freeze(model)

        model.eval()

        for i in range(k):
            print('-' * 40)
            print('Random {a}'.format(a=i))

            # Random initialization, same initialization as huggingface
            weight = model.get_input_embeddings().weight.data[-1]
            nn.init.normal_(weight, mean=0.0, std=model.config.init_std)

            # Before training
            # print('Before training:')
            # We need a Text2TextGenerationPipeline here because mBART was originally built for translation;
            # only this way can one <mask> be replaced by multiple predicted words.
            # nlp = Text2TextGenerationPipeline(model=model, tokenizer=self.builder.tokenizer, device=0)

            model = self._train(model, vec_targets, queries)

            print("*************************************************************************")
            # After training
            print('After training:')
            nlp = Text2TextGenerationPipeline(model=model, tokenizer=self.builder.tokenizer, device=0)
            for new_query in set(new_queries):  # only view different queries
                print(f"query: {new_query}")
                output = nlp(new_query, max_length=30, num_return_sequences=5, num_beams=100)  # TODO output looks fishy...
                output = self._format(output)
                print(f'output: {output}')

                outputs_list.append(output)

                output = self._predict_z(model, query)
                output = self._format(output)
                print(f'{NEW_TOKEN} {output}')
            print("*************************************************************************")

    def _train(self, model, vec_targets, queries):
        loss_fct = nn.MSELoss(reduction='mean')  # MSE between the contextual vectors and the target vectors
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.3, eps=1e-8)
        epoch = 1000
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=epoch)

        # This snippet determines the padding length:
        #  (a) encode each query's text (query[0]),
        #  (b) take the resulting input_ids (the [0] after encode),
        #  (c) count them (.shape[-1], since the number of input_ids is in the last dimension).
        # The max over all queries tells us how far the other queries must be padded.
        max_length = max([(self.builder.encode(query[0])[0]).shape[-1] for query in queries])

        input_ids_and_gather_indexes = [self.builder.encode(query[0], max_length=max_length) for query in queries]
        input_ids = torch.cat([input_id for input_id in [i for i, _ in input_ids_and_gather_indexes]], dim=0).to("cuda")
        gather_indexes = [gather_index for gather_index in [g for _, g in input_ids_and_gather_indexes]]

        # target_idx is the index of target word in the token list.
        target_idxs = [g[q[1] + 1][0] for g, q in zip(gather_indexes, queries)]
        target_idxs = torch.tensor(target_idxs, device="cuda").unsqueeze(-1)
        # token_idx is the index of the target word in the tokenizer's vocabulary (here mBART-50's)
        token_idxs = input_ids.gather(dim=-1, index=target_idxs)
        vocab_size = len(tokenizer.get_vocab())  # can be checked with tokenizer.get_added_vocab()
        min_token_idx = min(token_idxs)
        # Get all indices smaller than the new token_idx:
        indices = torch.tensor([i for i in range(vocab_size) if i < min_token_idx], device="cuda", dtype=torch.long)

        for _ in trange(epoch):
            model.zero_grad()
            outputs = model(input_ids, output_hidden_states=True)
            z = torch.index_select(outputs.decoder_hidden_states[12][0], dim=0, index=target_idxs.squeeze(-1))
            # or:
            # outputs = model(input_ids) ??
            # z = torch.index_select(outputs.encoder_last_hidden_state[0], dim=0, index=target_idxs.squeeze(-1)) ??

            loss = loss_fct(z, torch.stack(vec_targets))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # or:
            # model.get_input_embeddings().weight.grad.data[indices] = 0 ??
            # model.model.encoder.embed_positions.weight.data[indices] = 0 ??
            # model.model.encoder.embed_tokens.weight.grad[indices] = 0 ??
            model.model.shared.weight.grad[indices] = 0
            optimizer.step()
            scheduler.step()

            # try to fix the feed-forward bug
            outputs = model(input_ids)
            bert_z = torch.index_select(outputs.encoder_last_hidden_state[0], dim=0, index=target_idxs.squeeze(-1))

        # get the z* for classification
        vec = model.get_input_embeddings()(token_idxs).squeeze(1)[0]  # this is z*; [0] because all the same
        vec_array = vec.cpu().detach().numpy()
        z_list.append(vec_array)
        loss_list.append(str(loss.cpu().detach().numpy()))

        # save checkpoints
        np.save(CACHE + "temp_z_arrays_mbart.npy", np.array(z_list))
        np.save(CACHE + "temp_loss_arrays_mbart.npy", np.array(loss_list))

        s = 'Final loss={a}'.format(a=str(loss.cpu().detach().numpy()))
        print(s)

        return model

    def _get_target_embed(self, target, model):
        input_ids, gather_indexes = self.builder.encode(target[0])
        target_idx = gather_indexes[target[1] + 1][0]
        model.eval()
        with torch.no_grad():
            # Find the learning target x
            input_ids = input_ids.to('cuda')
            outputs = model(input_ids)
            # encoder relevant for downstream tasks
            x_target = outputs.encoder_last_hidden_state[0, target_idx]
        return x_target

    def _freeze(self, model):
        for name, param in model.named_parameters():
            if 'model.encoder.embed_positions' in name or 'model.decoder.embed_positions' in name:
                param.requires_grad = True
            elif 'model.shared' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        return model

    def _format(self, results):  # new format
        reval = []
        for item in results:
            if "generated_text" in item.keys():
                generated_text = item["generated_text"]
                reval.append(generated_text)
            else:
                token_str = item['token_str']
                score = item['score']
                s = ':'.join([token_str, str(score)])
                reval.append(s)
        return reval

    def _predict_z(self, model, query):
        input_ids, gather_indexes = self.builder.encode(query[0])
        # target_idx is the index of target word in the token list.
        target_idx = gather_indexes[query[1] + 1][0]
        input_ids = input_ids.to('cuda')
        outputs = model(input_ids)
        with torch.no_grad():
            logits = outputs.logits[0, target_idx, :]
        probs = logits.softmax(dim=0)
        values, predictions = probs.topk(5)
        reval = []
        for v, p in zip(values.tolist(), predictions.tolist()):
            s = {
                'score': v,
                'token_str': self.builder.tokenizer.convert_ids_to_tokens(p)
            }
            reval.append(s)
        return reval


def load_data(path: str) -> List[Example]:
    reval = []
    with jsonlines.open(path) as reader:
        for obj in reader:
            target = (obj['target'], obj['target_idx'])
            query = (obj['query'], obj['query_idx'])
            reval.append((target, query))
    return reval


def get_lowest_loss_arrays(z_list, loss_list):
    z_array = np.array(z_list)
    loss_array = np.array(loss_list)

    loss_list = loss_array.tolist()
    z_list = []  # list of arrays

    loss_list = list(map(float, loss_list))

    # print(z_array)
    for vec in z_array:
        # print("vec.shape", vec.shape) #(768,)
        z_list.append(vec)

    # empty lists
    z_temp = []
    loss_temp = []

    # 5 initializations
    r = int(len(loss_list) / 5)

    # For each group of 5 initializations, keep the z with the lowest loss.
    for i in range(r):
        k = 0
        for j in range(5):
            if k == 0 or loss_list[5 * i + j] < k:
                k = loss_list[5 * i + j]
                z = z_list[5 * i + j]

        z_temp.append(z)
        loss_temp.append(k)

    z_temp_array = np.array(z_temp)

    return z_temp_array


if __name__ == '__main__':

    z_list = []
    z_eps_list = []
    loss_list = []
    outputs_list = []

    with open(QUERIES_PATH) as json_file:
        data = json.load(json_file)

    # Group the dataset into a list of lists where the label of the dictionaries is identical:
    data.sort(key=lambda x: x["label"])
    data = [list(group) for _, group in itertools.groupby(data, key=lambda x: x["label"])]

    tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="de_DE", tgt_lang="de_DE")
    builder = DataBuilder(tokenizer)
    co = Coercion(builder)
    for group in data:
        co.coercion(group)
        print('==' * 40)

    result = get_lowest_loss_arrays(z_list, loss_list)

    np.save(DIR_OUT + 'pseudowords_comapp.npy', result)

Here is an example part from the json file CoMaPP_all.json I’m using:

[{"label": "jetzt1631", "target1": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .", "target1_idx": 8, "query": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es <mask> .", "query_idx": 8}, {"label": "heisst1631", "target1": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .", "target1_idx": 9, "query": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es <mask> .", "query_idx": 9}, {"label": "es1631", "target1": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .", "target1_idx": 10, "query": "«Ich werde nur ganz wenig feiern , denn jetzt heisst es <mask> .", "query_idx": 10}, {"label": "jetzt1631", "target1": "Aber jetzt heisst es vorwärtsschauen .", "target1_idx": 1, "query": "Aber jetzt heisst es <mask> .", "query_idx": 1}, {"label": "heisst1631", "target1": "Aber jetzt heisst es vorwärtsschauen .", "target1_idx": 2, "query": "Aber jetzt heisst es <mask> .", "query_idx": 2}, {"label": "es1631", "target1": "Aber jetzt heisst es vorwärtsschauen .", "target1_idx": 3, "query": "Aber jetzt heisst es <mask> .", "query_idx": 3}]

CoMaPP_Dataset.csv looks as follows:

label,query,mask,ambigous_word
jetzt1631,"«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .",realistisch bleiben,jetzt
heisst1631,"«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .",realistisch bleiben,heisst
es1631,"«Ich werde nur ganz wenig feiern , denn jetzt heisst es realistisch bleiben .",realistisch bleiben,es
jetzt1631,Aber jetzt heisst es vorwärtsschauen .,vorwärtsschauen,jetzt
heisst1631,Aber jetzt heisst es vorwärtsschauen .,vorwärtsschauen,heisst
es1631,Aber jetzt heisst es vorwärtsschauen .,vorwärtsschauen,es

I have the following problem: after training, as soon as I test a new_query in the loop for new_query in set(new_queries):, I get outputs like this:

new_query: '#TOKEN# für <mask> !"" , schrieb der Milliardär am Mittwochmorgen ( Ortszeit ) in dem Kurzbotschaftendienst .'
output: [{'generated_text': '#TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN#'}, {'generated_text': '#TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN#'}, {'generated_text': 'na #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN#'}, {'generated_text': 'con #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN#'}, {'generated_text': 'Con #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN# #TOKEN#'}]

Using the debugger, I found that the probability of #TOKEN# is 1.0, whereas all other tokens' probabilities are 0.0.
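
Roughly, this is the kind of check I did (a sketch using the model, builder and NEW_TOKEN from the script above; the exact calls in my debugger session may have differed slightly):

with torch.no_grad():
    enc = builder.tokenizer(new_query, return_tensors="pt").to("cuda")
    out = model(**enc)                  # decoder inputs are derived from input_ids internally
    probs = out.logits.softmax(dim=-1)  # shape: (1, sequence_length, vocab_size)
    token_id = builder.tokenizer.convert_tokens_to_ids(NEW_TOKEN)
    print(probs[0, :, token_id])        # probability mass assigned to #TOKEN# at each position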

What did I do wrong? I know that BERT and (m)BART are different architectures, but mBART is also pretrained with a masking-based (denoising/text-infilling) objective, so I think it should be possible to do something similar to what the original code does with BERT.


Pinging @lewtun since they might know more. Thanks, @muhtasham :ok_hand: