Bert2bert translator?

Hello,

I am trying to get my hands on transformers (this is my first project with them). I decided to build a bert2bert translator, as it is one of the models tested in the following paper: https://arxiv.org/pdf/1907.12461

I put my tests here: Bert2Bert_translator/Bert_translator.ipynb at commit 0fb904c480df2a2de53f51e9b9198b65b6fcf770 (jclary-31/Bert2Bert_translator on GitHub).

I used EncoderDecoderModel to combine one BERT as the encoder and another as the decoder. I then fine-tuned the model, but something is off.

Maybe I am using the wrong BERT checkpoint, maybe the encoder inputs are not correct (but this step should be automatic), or maybe it is something else. Should I separate the encoder and decoder?

I don't know where the problem lies. I tried a bigger dataset and it changes nothing: in the end, my output on a translation task is still something like '[CLS] [CLS] [CLS]'. So I think the issue is in the design: something I missed or misunderstood.

I checked forums, GitHub and various websites, and found no concrete example of such a translator.


Do you know what is wrong? Is it in the code or in the design?

Thanks


There seem to be several known cases like this. I had an AI write some demo code.

import torch
from transformers import (
    BertTokenizerFast, BertConfig, BertLMHeadModel, BertModel,
    AutoModel, EncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
)

torch.manual_seed(0)
enc = dec = "bert-base-uncased"
tok_src = BertTokenizerFast.from_pretrained(enc)
tok_tgt = BertTokenizerFast.from_pretrained(dec)

# ---------- WRONG_1: BOS loop risk (labels include BOS + manual decoder_input_ids)
dec_cfg = BertConfig.from_pretrained(dec, is_decoder=True, add_cross_attention=True)
bad_train = EncoderDecoderModel(
    encoder=AutoModel.from_pretrained(enc),
    decoder=BertLMHeadModel.from_pretrained(dec, config=dec_cfg),
)
X = tok_src(["i like tea"], return_tensors="pt", padding=True, truncation=True)
Y = tok_tgt(["j'aime le thé"], return_tensors="pt", padding=True, truncation=True)  # has [CLS]
labels = Y.input_ids.clone(); labels[labels == tok_tgt.pad_token_id] = -100
_ = bad_train(input_ids=X["input_ids"], attention_mask=X["attention_mask"],
              decoder_input_ids=Y.input_ids, labels=labels)  # ❌
gen = bad_train.generate(
    X["input_ids"], attention_mask=X["attention_mask"], max_new_tokens=8,
    decoder_start_token_id=tok_tgt.cls_token_id, eos_token_id=tok_tgt.sep_token_id, pad_token_id=tok_tgt.pad_token_id
)
print("WRONG_1 gen ids:", gen[0][:8].tolist())

# ---------- WRONG_2: decoder lacks LM head / cross-attn
plain_decoder = BertModel.from_pretrained(dec)  # ❌
broken = EncoderDecoderModel(encoder=AutoModel.from_pretrained(enc), decoder=plain_decoder)
try:
    lbl2 = tok_tgt(["les chats sont mignons"], return_tensors="pt",
                   padding=True, truncation=True, add_special_tokens=False).input_ids
    lbl2[lbl2 == tok_tgt.pad_token_id] = -100
    _ = broken(input_ids=X["input_ids"], attention_mask=X["attention_mask"], labels=lbl2)
    print("WRONG_2 ran (decoder misconfigured)")
except Exception as e:
    print("WRONG_2 error:", type(e).__name__)

# ---------- CORRECT wiring: set decoder_start_token_id ON CONFIG before forward
# (the model is still untrained here, so its generations are expected to be noise)
dec_cfg_ok = BertConfig.from_pretrained(dec, is_decoder=True, add_cross_attention=True)
good = EncoderDecoderModel(
    encoder=AutoModel.from_pretrained(enc),
    decoder=BertLMHeadModel.from_pretrained(dec, config=dec_cfg_ok),
)
# Required for loss computation (right-shift uses this)
good.config.decoder_start_token_id = tok_tgt.cls_token_id
good.config.eos_token_id = tok_tgt.sep_token_id
good.config.pad_token_id = tok_tgt.pad_token_id
good.config.vocab_size = good.config.decoder.vocab_size
good.config.tie_encoder_decoder = False

X2 = tok_src(["cats are cute", "i like tea"], return_tensors="pt", padding=True, truncation=True)
Y2 = tok_tgt(["les chats sont mignons", "j'aime le thé"], return_tensors="pt",
             padding=True, truncation=True, add_special_tokens=False)  # no [CLS]
labels2 = Y2.input_ids.clone(); labels2[labels2 == tok_tgt.pad_token_id] = -100
_ = good(input_ids=X2["input_ids"], attention_mask=X2["attention_mask"], labels=labels2)  # ✅ no error

gen2 = good.generate(
    X2["input_ids"], attention_mask=X2["attention_mask"],
    num_beams=4, max_new_tokens=24, no_repeat_ngram_size=3, early_stopping=True,
    decoder_start_token_id=tok_tgt.cls_token_id, eos_token_id=tok_tgt.sep_token_id, pad_token_id=tok_tgt.pad_token_id
)
print("CORRECT gen:", [tok_tgt.decode(g, skip_special_tokens=True) for g in gen2])

# ---------- CHECK: known-good BERT2BERT
name = "google/bert2bert_L-24_wmt_en_de"
tok_g = AutoTokenizer.from_pretrained(name, pad_token="<pad>", bos_token="<s>", eos_token="</s>")
mdl_g = AutoModelForSeq2SeqLM.from_pretrained(name)
ids = tok_g("Would you like a coffee?", return_tensors="pt", add_special_tokens=False).input_ids
print("CHECK gen:", tok_g.decode(mdl_g.generate(ids, num_beams=4, max_new_tokens=32)[0], skip_special_tokens=True))

#WRONG_1 gen ids: [101, 6730, 6730, 6730, 6730, 6730, 6730, 6730]
#WRONG_2 error: ValueError
#CORRECT gen: ['played rule rule rule rules rule rule play rule play play rule rule pass rule play pass rule rule win rule rule flow rule', 'the. and and and pass pass pass rule rule rule pass pass be rule rule be rule pass rule pass be pass pass']
#CHECK gen: Haben Sie Lust auf einen Kaffee?
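A toy illustration of why the WRONG_1 setup tends to loop on [CLS]: if the same unshifted sequence is fed as both decoder_input_ids and labels, each position's target is the token the decoder is already looking at, so a trivial "copy the last token" rule scores perfectly and training can drift toward it. This is plain Python, not transformers code; the token ids other than BERT's 101/102 and the copy rule itself are made up for illustration.

```python
CLS, SEP = 101, 102  # BERT [CLS] / [SEP] ids; the other ids below are made-up placeholders
sentence = [CLS, 2001, 2002, 2003, SEP]  # stands in for "[CLS] j' aime le thé [SEP]"


def copy_last_token(prefix):
    """Degenerate 'model': always predicts the last token it was given."""
    return prefix[-1]


def teacher_forced_accuracy(predict, decoder_input_ids, labels):
    """Position t sees decoder_input_ids[:t+1] and must predict labels[t]."""
    hits = sum(
        predict(decoder_input_ids[: t + 1]) == labels[t] for t in range(len(labels))
    )
    return hits / len(labels)


def greedy_decode(predict, start_id, steps):
    """Feed each prediction back in, starting from start_id."""
    out = [start_id]
    for _ in range(steps):
        out.append(predict(out))
    return out


# WRONG_1 style: decoder inputs and labels are the same, unshifted sequence.
wrong_acc = teacher_forced_accuracy(copy_last_token, sentence, sentence)
# Shifted style: inputs start with [CLS], labels are the same tokens one step ahead.
shifted_acc = teacher_forced_accuracy(copy_last_token, sentence[:-1], sentence[1:])

print(wrong_acc, shifted_acc)                  # 1.0 0.0
print(greedy_decode(copy_last_token, CLS, 4))  # [101, 101, 101, 101, 101]
```

With the shifted setup the copy shortcut scores zero, so the model has to learn to predict the next token instead; that shift is what setting config.decoder_start_token_id enables in the CORRECT section.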

Hello,

I made a small, quick test notebook following your advice: Bert2Bert_translator/bert2bert_quicktest.ipynb (main branch of jclary-31/Bert2Bert_translator on GitHub).

So,

  1. The repeated [CLS] [CLS] output is no longer generated. I am not sure whether the fix was using BertLMHeadModel, passing decoder_start_token_id=tok_tgt.cls_token_id when generating, or both.
  2. The generated output makes no sense at all, and from my tests the result mostly depends on the no_repeat_ngram_size and num_beams parameters.

With no_repeat_ngram_size set, a few different words are generated; without it, the same word is repeated again and again. It looks like the "CORRECT gen" output in your last answer ('played rule rule rule rules rule rule', and so on).
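A plain-Python sketch of what no_repeat_ngram_size does, written to mirror (as far as I understand it) the n-gram blocking in transformers' generate: before each step, any token that would complete an n-gram already present in the output is banned. The toy vocabulary and the fixed preference order stand in for a degenerate, untrained model; both are assumptions for illustration.

```python
def banned_next_tokens(generated, n):
    """Tokens that would recreate an n-gram already present in `generated`."""
    if len(generated) < n - 1:
        return set()
    prefix = tuple(generated[len(generated) - (n - 1):])
    banned = set()
    for i in range(len(generated) - n + 1):
        ngram = tuple(generated[i : i + n])
        if ngram[:-1] == prefix:  # same (n-1)-token context seen before
            banned.add(ngram[-1])
    return banned


def greedy_with_ngram_block(preference, steps, n):
    """Greedy decoding for a toy 'model' whose ranking never changes; n=0 disables blocking."""
    out = []
    for _ in range(steps):
        banned = banned_next_tokens(out, n) if n > 0 else set()
        out.append(next(tok for tok in preference if tok not in banned))
    return out


prefs = ["rule", "play", "pass"]
print(greedy_with_ngram_block(prefs, 6, 0))  # ['rule', 'rule', 'rule', 'rule', 'rule', 'rule']
print(greedy_with_ngram_block(prefs, 6, 3))  # ['rule', 'rule', 'rule', 'play', 'rule', 'rule']
```

The blocked output only reshuffles the same degenerate preferences, which matches the observation above: the parameter changes which words appear, but it cannot make an untrained model translate.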

In my main code, where I test fine-tuning, if I don't use no_repeat_ngram_size, the generated text remains '[CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'.
If I set no_repeat_ngram_size=3, the generated text is:
[CLS] [PAD] [PAD] [PAD], [PAD] [PAD] of [PAD] [PAD] and [PAD] [PAD]esian [PAD] [PAD] lucas [PAD] [PAD]chfield [PAD]

So I think there is still an attention-head issue. Do you know how to fix it? Should I update Bert_translator.ipynb on GitHub so you can see it?


The solution above only suppresses the PAD tokens.
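Blocking a single-token entry such as bad_words_ids=[[PAD_ID]] amounts to setting that token's score to minus infinity before the next token is picked. The sketch below shows that idea in plain Python; the four-token vocabulary and its scores are made up, and only PAD id 0 matches BERT's real [PAD].

```python
import math

PAD_ID = 0  # [PAD] in BERT vocabularies; the other ids and scores below are toy values


def block_tokens(scores, banned_ids):
    """Give banned token ids a score of -inf so they can never be chosen."""
    return [-math.inf if i in banned_ids else s for i, s in enumerate(scores)]


def argmax(scores):
    """Index of the highest score."""
    return max(range(len(scores)), key=scores.__getitem__)


scores = [5.0, 1.0, 3.0, 0.5]  # an under-trained decoder that prefers [PAD]
print(argmax(scores))                          # 0 -> [PAD]
print(argmax(block_tokens(scores, {PAD_ID})))  # 2 -> whatever ranks second
```

The model's preferences are unchanged; the second-best token simply wins, which is why suppressing PAD alone does not produce a translation.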

When actually implementing this, you will need to perform actual training and use a tokenizer that supports both languages.
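One concrete reason the tokenizer matters: bert-base-uncased lowercases and, in that mode, strips accents by default, so French accents never reach the model. The function below is a rough plain-Python approximation of that normalization step (lowercase, Unicode NFD, drop combining marks); it is not the actual BertTokenizer code and ignores the WordPiece step entirely.

```python
import unicodedata


def uncased_bert_normalize(text):
    """Approximate the lowercase + accent stripping done by uncased BERT tokenizers."""
    text = unicodedata.normalize("NFD", text.lower())
    # After NFD, accents are separate combining characters (category "Mn"); drop them.
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")


print(uncased_bert_normalize("J'aime le thé"))  # j'aime le the
print(uncased_bert_normalize("Où est-il ?"))    # ou est-il ?
```

A cased multilingual vocabulary such as bert-base-multilingual-cased keeps these characters, which is why it is used for the French decoder in the code that follows.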

# pip install -U transformers datasets
import random, math
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, BertConfig, BertLMHeadModel, EncoderDecoderModel
)

# ---- config
SEED = 0
SRC_CKPT = "bert-base-uncased"              # encoder (EN)
TGT_CKPT = "bert-base-multilingual-cased"   # decoder (FR-capable)
MAX_SRC_LEN = 96
MAX_TGT_LEN = 96
BATCH_SIZE = 8
EPOCHS = 10                                 # raise to 20–30 if not overfitting
LR = 5e-5

random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- tokenizers
tok_src = AutoTokenizer.from_pretrained(SRC_CKPT)
tok_tgt = AutoTokenizer.from_pretrained(TGT_CKPT)
PAD_ID = tok_tgt.pad_token_id
EOS_ID = tok_tgt.sep_token_id
BOS_ID = tok_tgt.cls_token_id

# ---- model: BERT encoder + BERT LM-head decoder with cross-attn
dec_cfg = BertConfig.from_pretrained(TGT_CKPT, is_decoder=True, add_cross_attention=True)
model = EncoderDecoderModel(
    encoder=AutoModel.from_pretrained(SRC_CKPT),
    decoder=BertLMHeadModel.from_pretrained(TGT_CKPT, config=dec_cfg),
).to(device)
# required special ids for training (right-shift) and decode
model.config.decoder_start_token_id = BOS_ID
model.config.eos_token_id = EOS_ID
model.config.pad_token_id = PAD_ID
model.config.tie_encoder_decoder = False
model.config.vocab_size = model.config.decoder.vocab_size

# ---- tiny EN–FR set: take 100 pairs from OPUS Books
# notes: you can replace this with your own parallel lists
ds = load_dataset("Helsinki-NLP/opus_books", "en-fr", split="train")  # full en-fr train split; only a small subset is used below
pairs = [(ex["translation"]["en"], ex["translation"]["fr"]) for ex in ds.select(range(2000))]
random.shuffle(pairs)
pairs = pairs[:100]  # exactly 100
src_list, tgt_list = zip(*pairs)

# ---- helpers
def build_batch(src_texts, tgt_texts):
    # source
    X = tok_src(
        list(src_texts), padding=True, truncation=True, max_length=MAX_SRC_LEN, return_tensors="pt"
    )
    # target labels: NO BOS; append EOS; mask PAD with -100
    Y = tok_tgt(
        list(tgt_texts), padding="max_length", truncation=True, max_length=MAX_TGT_LEN,
        add_special_tokens=False, return_tensors="pt"
    )["input_ids"]
    # append EOS before padding if room
    Y_fixed = torch.full_like(Y, PAD_ID)
    for i in range(Y.size(0)):
        toks = [t for t in Y[i].tolist() if t != PAD_ID]
        if len(toks) < MAX_TGT_LEN:
            toks = toks + [EOS_ID]
        toks = toks[:MAX_TGT_LEN]
        Y_fixed[i, :len(toks)] = torch.tensor(toks, dtype=Y_fixed.dtype)
    labels = Y_fixed.clone()
    labels[labels == PAD_ID] = -100

    return {k: v.to(device) for k, v in X.items()}, labels.to(device)

def collate(batch):
    s, t = zip(*batch)
    return build_batch(s, t)

# simple Dataset wrapper
class Pairs(torch.utils.data.Dataset):
    def __init__(self, srcs, tgts):
        self.s = list(srcs); self.t = list(tgts)
    def __len__(self): return len(self.s)
    def __getitem__(self, i): return self.s[i], self.t[i]

train_dl = DataLoader(Pairs(src_list, tgt_list), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

@torch.inference_mode()
def translate_samples(texts, n=5):
    X = tok_src(list(texts[:n]), return_tensors="pt", padding=True, truncation=True, max_length=MAX_SRC_LEN).to(device)
    out = model.generate(
        X["input_ids"], attention_mask=X["attention_mask"],
        num_beams=4, max_new_tokens=64, early_stopping=True,
        decoder_start_token_id=BOS_ID, eos_token_id=EOS_ID, pad_token_id=PAD_ID,
        bad_words_ids=[[PAD_ID]],          # block PAD
        repetition_penalty=1.1,            # mild
        no_repeat_ngram_size=3             # optional hygiene
    )
    return [tok_tgt.decode(o, skip_special_tokens=True) for o in out]

def show_before_after(k=5):
    print("\n--- BEFORE ---")
    preds_before = translate_samples(src_list, n=k)
    for i in range(k):
        print(f"EN: {src_list[i]}")
        print(f"FR_gold: {tgt_list[i]}")
        print(f"FR_pred: {preds_before[i]}")
        print("-")
    # train then test again
    model.train()
    opt = AdamW(model.parameters(), lr=LR)
    steps = 0
    for epoch in range(EPOCHS):
        for X, labels in train_dl:
            opt.zero_grad()
            out = model(input_ids=X["input_ids"], attention_mask=X["attention_mask"], labels=labels)
            out.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            steps += 1
        print(f"epoch {epoch+1}/{EPOCHS} done")
    model.eval()

    print("\n--- AFTER ---")
    preds_after = translate_samples(src_list, n=k)
    for i in range(k):
        print(f"EN: {src_list[i]}")
        print(f"FR_gold: {tgt_list[i]}")
        print(f"FR_pred: {preds_after[i]}")
        print("-")

if __name__ == "__main__":
    print(f"device: {device}")
    show_before_after(k=5)

"""
--- BEFORE ---
EN: As for me, I found myself obliged, the first time for months, to face alone a long Thursday evening - with the clear feeling that the old carriage had borne away my youth forever.
FR_gold: Quant à moi, je me trouvai, pour la première fois depuis de longs mois, seul en face d’une longue soirée de jeudi – avec l’impression que, dans cette vieille voiture, mon adolescence venait de s’en aller pour toujours.
FR_pred: ##iiilililiililiiliiliilingingiingiingiingingingingiiliiliingiingiiliiliigingingillingingighingiingingiingiiliingingiiliingiigiingiingieningingioviingiinginiingiingiiingiingighinginginingingiigingi
-
EN: No one asked him who Booby was.
FR_gold: Personne ne lui demanda qui était Ganache.
FR_pred: a a a - - - a a A A A a a ad ad ad Ad Ad Ad ad ad a a, a a ae ae ae a A a A,, A A, - -,,, a,,. - - an an an,, an an - - A A - - 1 -
-
EN: M. Seurel's here .. .'
FR_gold: M. Seurel est là

FR_pred: ##ggg22233322443344423243234377799988877889979773378789786779777688
-
EN: After the ball where everything was charming but feverish and mad, where he had himself so madly chased the tall Pierrot, Meaulnes found that he had dropped into the most peaceful happiness on earth.
FR_gold: Après cette fête où tout était charmant, mais fiévreux et fou, où lui-même avait si follement poursuivi le grand pierrot, Meaulnes se trouvait là plongé dans le bonheur le plus calme du monde.
FR_pred: ##iiilililiiiiliilililiiliiliigiigiigiiliiliiliingiingiingiiliilingingingiingiingiigiigingingiigiigiingiingingingiiliigiingiigingiingiigiingingiingingiigiingiiciingiingificiingiingiiciigiigiiciingi
-
EN: At half-past eight, just as M. Seurel was giving the signal to enter school, we arrived, quite out of breath, to line up.
FR_gold: À huit heures et demie, à l’instant où M. Seurel allait donner le signal d’entrer, nous arrivâmes tout essoufflés pour nous mettre sur les rangs.
FR_pred: ##jajajajanjanjanjajajanojanjanjaljanjan sal sal saljanjan sino sino sinojanjanjanojanojanojanjano sino sinojanojano sal salcolcolcolcalcalcalcolcol sal salsal sal salallallall sal sal alcolcolsalsalcolcol - - sal sal
-

--- AFTER ---
EN: As for me, I found myself obliged, the first time for months, to face alone a long Thursday evening - with the clear feeling that the old carriage had borne away my youth forever.
FR_gold: Quant à moi, je me trouvai, pour la première fois depuis de longs mois, seul en face d’une longue soirée de jeudi – avec l’impression que, dans cette vieille voiture, mon adolescence venait de s’en aller pour toujours.
FR_pred: Quant à moi, je ne voulus pas pour la première fois de soi, seul en face d une longue longue aventure de longs mois.
-
EN: No one asked him who Booby was.
FR_gold: Personne ne lui demanda qui était Ganache.
FR_pred: Personne ne lui demanda qui demanda demanda qui lui demanda demanda qu il demanda Ganache.
-
EN: M. Seurel's here .. .'
FR_gold: M. Seurel est là

FR_pred: M. Seurel est là
-
EN: After the ball where everything was charming but feverish and mad, where he had himself so madly chased the tall Pierrot, Meaulnes found that he had dropped into the most peaceful happiness on earth.
FR_gold: Après cette fête où tout était charmant, mais fiévreux et fou, où lui-même avait si follement poursuivi le grand pierrot, Meaulnes se trouvait là plongé dans le bonheur le plus calme du monde.
FR_pred: Dès qu on le recommença plus le grand pierrot de sa société où lui même même même avait si beau.
-
EN: At half-past eight, just as M. Seurel was giving the signal to enter school, we arrived, quite out of breath, to line up.
FR_gold: À huit heures et demie, à l’instant où M. Seurel allait donner le signal d’entrer, nous arrivâmes tout essoufflés pour nous mettre sur les rangs.
FR_pred: À huit heures et demie à peine, nous arrivâmes tout tout essoufflés sur les rangs.
-
"""

Hello John, thank you very much for your help.

so,

  1. Oops, sorry: I forgot to switch to training mode with model.train() in my small quick test. My mistake.
  2. I am French, so letters such as 'é' or 'ù' are completely natural to me, and I forgot they do not exist in English. So yes, the encoder and decoder are different.
  3. It seems that the decoder does not need a BOS, and that EOS is not required either if the sentence is truncated. I didn't know that, and it can change the output sentences. I assume the decoder creates the BOS and EOS.
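To make point 3 concrete, here is a plain-Python sketch of the two steps involved, assuming the labels are built as in build_batch above: the labels carry no [CLS], get a [SEP] only if there is room, and are padded with -100. When only labels are passed, EncoderDecoderModel shifts them right to build decoder_input_ids, putting decoder_start_token_id ([CLS] here) in front; shift_right below is a hand-written imitation of that, not the library function. Token ids other than BERT's 0/101/102 are made up.

```python
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102  # BERT special-token ids
IGNORE = -100                          # label positions the loss skips


def build_labels(token_ids, max_len):
    """Labels as in build_batch: no [CLS], [SEP] only if there is room, -100 padding."""
    toks = list(token_ids)[:max_len]
    if len(toks) < max_len:
        toks.append(SEP_ID)
    return toks + [IGNORE] * (max_len - len(toks))


def shift_right(labels, start_id, pad_id):
    """Decoder inputs: start token in front, labels moved one step right, -100 -> pad."""
    shifted = [start_id] + labels[:-1]
    return [pad_id if t == IGNORE else t for t in shifted]


labels = build_labels([2001, 2002, 2003], max_len=6)
print(labels)                               # [2001, 2002, 2003, 102, -100, -100]
print(shift_right(labels, CLS_ID, PAD_ID))  # [101, 2001, 2002, 2003, 102, 0]

cut = build_labels([1, 2, 3, 4, 5], max_len=4)  # too long: truncated, no room for [SEP]
print(cut)                                      # [1, 2, 3, 4]
```

So the BOS does come from the shift, but the EOS does not: the model only learns to emit [SEP] from labels that contain it, and a truncated sentence teaches it nothing about where to stop.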

Thanks a lot for your help, I learned a lot. For example, I was not aware of the repetition_penalty or no_repeat_ngram_size parameters.
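For reference, a plain-Python sketch of how repetition_penalty is applied, as I understand transformers' RepetitionPenaltyLogitsProcessor: every token already generated has its score divided by the penalty if it is non-negative, or multiplied by it if negative, so values above 1.0 make repeats less likely. The scores and token ids are made up.

```python
def apply_repetition_penalty(scores, generated_ids, penalty):
    """Push down the scores of tokens that were already generated."""
    out = list(scores)
    for tok in set(generated_ids):
        s = out[tok]
        out[tok] = s * penalty if s < 0 else s / penalty
    return out


scores = [2.0, 1.9, -1.0]  # token 0 is slightly preferred
penalized = apply_repetition_penalty(scores, generated_ids=[0, 2], penalty=1.1)
print(penalized)  # roughly [1.818, 1.9, -1.1]: token 1 now has the top score
```

Unlike no_repeat_ngram_size, this is a soft penalty: a strongly preferred token can still be repeated.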

If I may ask, why set model.config.tie_encoder_decoder = False?


why model.config.tie_encoder_decoder = False?

I thought it would be problematic to set this parameter to True when the encoder and decoder come from two different models.

tie_encoder_decoder (bool, optional, defaults to False) — Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names.
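A small illustration of why the "exact same parameter names" requirement bites in this setup. The dictionaries below are hypothetical hand-written excerpts (name relative to the BERT body, mapped to shape), not loaded from the checkpoints; the vocabulary sizes are those of bert-base-uncased (30522) and bert-base-multilingual-cased (119547), and the cross-attention entry only exists on the decoder side because of add_cross_attention=True.

```python
# Hypothetical excerpts: parameter name -> shape, written by hand for illustration.
ENCODER = {  # bert-base-uncased body
    "embeddings.word_embeddings.weight": (30522, 768),
    "encoder.layer.0.attention.self.query.weight": (768, 768),
}
DECODER = {  # bert-base-multilingual-cased body with add_cross_attention=True
    "embeddings.word_embeddings.weight": (119547, 768),
    "encoder.layer.0.attention.self.query.weight": (768, 768),
    "encoder.layer.0.crossattention.self.query.weight": (768, 768),
}


def tie_report(enc, dec):
    """List what would block sharing every encoder weight with the decoder."""
    shape_mismatch = sorted(n for n in enc if n in dec and enc[n] != dec[n])
    missing_in_decoder = sorted(n for n in enc if n not in dec)
    decoder_only = sorted(n for n in dec if n not in enc)
    return shape_mismatch, missing_in_decoder, decoder_only


print(tie_report(ENCODER, DECODER))
# (['embeddings.word_embeddings.weight'], [], ['encoder.layer.0.crossattention.self.query.weight'])
```

The embedding shapes alone could not be shared between these two checkpoints; with identical checkpoints tying would be possible, but it is a modelling choice rather than a default.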

