How to train a translation model from scratch

I’ve recently been working on text punctuation restoration, which is a problem where you have some text with missing punctuation and want to add it back.

Reading some papers, it seems one of the best approaches is to use Transformers as if you were doing translation, from a “language” that has no punctuation to one that does (e.g. “hello how are you” → “Hello, how are you?”).

I am trying to use :hugs: Hugging Face transformers, but I’ve been struggling to find good resources on how to train a translation network from scratch. Most of the documentation is about other tasks, and when it comes to translation I’ve only found docs that explain how to use pre-trained models.

Could someone point me in the right direction?

I’m also still testing a naïve approach where I give the model some text without punctuation and, for each input token, it predicts whether that token is preceded by punctuation or not.
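
Roughly, this is what I have in mind for that naïve approach: treat it as token classification, with one label per token (0 = no punctuation before this token, 1 = punctuation before it). Just a sketch with a made-up sentence and made-up labels, assuming a recent transformers version, not working code from my project:

import torch
import transformers

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
model = transformers.BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)

text = "hello how are you i am fine"  # punctuation stripped
inputs = tokenizer(text, return_tensors='pt')

# Made-up labels, one per tokenized position ([CLS] hello how are you i am fine [SEP]),
# with a 1 on "i" because the original text had a period before it.
labels = torch.tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0]])

outputs = model(**inputs, labels=labels)
print(outputs[0])  # the loss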

Sorry if there’s something really obvious in the documentation that I missed.

Thanks for reading this and I wish you a wonderful day!

I don’t think we have a really good example for seq2seq training from scratch right now, but you can take a look at the examples/seq2seq folder, which implements fine-tuning, and adapt it.

This is using pytorch-lightning.

@sshleifer can also chime in.

Thanks for replying!

I will take a look at examples/seq2seq.

I would like to fine-tune pretrained models on my domain-specific dataset. MarianMT pretrained models are available for many language pairs.

I would like to know how to either train a translation model from scratch or fine-tune the available pretrained models.

The code for the MarianMT tokenizer is there, and I saw that MarianMT inherits its architecture config from BART, but I am not sure how to fine-tune it on my datasets.
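
For context, this is roughly where I am: I can load a pretrained checkpoint and run a forward pass with labels, but I am not sure how to go from there to a proper fine-tuning loop. (Just a sketch; Helsinki-NLP/opus-mt-en-de and the sentences are only examples, and the tokenizer API may differ a bit between transformers versions.)

import transformers

model_name = 'Helsinki-NLP/opus-mt-en-de'  # example checkpoint, pick your language pair
tokenizer = transformers.MarianTokenizer.from_pretrained(model_name)
model = transformers.MarianMTModel.from_pretrained(model_name)

src_texts = ["A sentence from my domain-specific data."]
tgt_texts = ["Ein Satz aus meinen domänenspezifischen Daten."]

inputs = tokenizer(src_texts, return_tensors='pt', padding=True)
with tokenizer.as_target_tokenizer():  # tokenize the target side with the target vocab
    labels = tokenizer(tgt_texts, return_tensors='pt', padding=True)['input_ids']

outputs = model(**inputs, labels=labels)  # labels are shifted internally for the decoder
print(outputs.loss)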

So, I had a look at “examples/seq2seq”, but unfortunately I couldn’t extract much from the code. I think I need to practice my code-reading skills a bit more hehe.

I recently tried to use an EncoderDecoderModel to reproduce my encoder inputs at the decoder outputs, but I couldn’t make it work properly.

I tried to overfit a model using a set of 8 sentences to check if the model was working properly, but it just kept repeating the same token over and over.

Here is the code:

import torch
import transformers

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Randomly initialised BERT-to-BERT encoder-decoder
encoder_config = transformers.BertConfig()
decoder_config = transformers.BertConfig()

config = transformers.EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = transformers.EncoderDecoderModel(config)
model = model.to(device)

model.train()

# `data` holds my 8 tokenized sentences; the target is simply the input itself,
# since I only want the model to reconstruct what it reads
inputs = data['input_ids']
outputs = inputs

inputs = inputs.to(device)
outputs = outputs.to(device)

optimizer = torch.optim.AdamW(model.parameters())
epochs = 50

for e in range(epochs):
    optimizer.zero_grad()
    # tuple-style indexing: the first two outputs are the loss and the logits
    loss, output = model(input_ids=inputs,
                         decoder_input_ids=outputs,
                         labels=outputs)[:2]
    l_numerical = loss.item()
    print(f"Epoch: {e}, Loss: {l_numerical}")
    loss.backward()
    optimizer.step()

I think it may be something related to the way I am feeding it inputs for the decoder, but I am not sure.

Here are my outputs and loss values. It looks like my loss is stuck:

Epoch: 0, Loss: 10.43316650390625
Epoch: 1, Loss: 9.597984313964844
Epoch: 2, Loss: 8.604254722595215
Epoch: 3, Loss: 7.791357040405273
Epoch: 4, Loss: 7.0663299560546875
Epoch: 5, Loss: 6.433999538421631
Epoch: 6, Loss: 5.967499256134033
Epoch: 7, Loss: 5.682089328765869
Epoch: 8, Loss: 5.520726203918457
Epoch: 9, Loss: 5.435981273651123
Epoch: 10, Loss: 5.400460720062256
Epoch: 11, Loss: 5.394018650054932
Epoch: 12, Loss: 5.399378776550293
Epoch: 13, Loss: 5.400329113006592
Epoch: 14, Loss: 5.3928022384643555
Epoch: 15, Loss: 5.382742404937744
Epoch: 16, Loss: 5.369752883911133
Epoch: 17, Loss: 5.36002254486084
Epoch: 18, Loss: 5.361392021179199
Epoch: 19, Loss: 5.364934921264648
Epoch: 20, Loss: 5.367682456970215
Epoch: 21, Loss: 5.367767810821533
Epoch: 22, Loss: 5.3658857345581055
Epoch: 23, Loss: 5.3620100021362305
Epoch: 24, Loss: 5.358331680297852
Epoch: 25, Loss: 5.357394218444824
Epoch: 26, Loss: 5.359387397766113
Epoch: 27, Loss: 5.359804153442383
Epoch: 28, Loss: 5.359673023223877
Epoch: 29, Loss: 5.359203338623047
Epoch: 30, Loss: 5.357675075531006
Epoch: 31, Loss: 5.356016635894775
Epoch: 32, Loss: 5.356872081756592
Epoch: 33, Loss: 5.357171535491943
Epoch: 34, Loss: 5.356071949005127
Epoch: 35, Loss: 5.355656623840332
Epoch: 36, Loss: 5.355906963348389
Epoch: 37, Loss: 5.356597900390625
Epoch: 38, Loss: 5.354917526245117
Epoch: 39, Loss: 5.355926990509033
Epoch: 40, Loss: 5.353925704956055
Epoch: 41, Loss: 5.35430908203125
Epoch: 42, Loss: 5.355685234069824
Epoch: 43, Loss: 5.355371475219727
Epoch: 44, Loss: 5.3547282218933105
Epoch: 45, Loss: 5.3535051345825195
Epoch: 46, Loss: 5.3542351722717285
Epoch: 47, Loss: 5.355259895324707
Epoch: 48, Loss: 5.3540568351745605
Epoch: 49, Loss: 5.353936672210693
tensor([[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
        [ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
         1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996]], device='cuda:0')

I am not using attention masks for either the encoder or the decoder because all sentences are complete; they fill every position of the input array.

I also used BertTokenizer on them so I think the input transformation is correct.

I hope this gives you as much context as possible.

Thank you a lot! :hugs:

@sshleifer could you give me a hand? I am sure there’s something really silly in my code making it not work.

finetune.py

The way to do it with seq2seq/finetune.py is to put the docs in a directory with the following format:

train.source
train.target
val.source
val.target
test.source
test.target

Line i of train.source should be the “corrupted” text, and line i of train.target should have punctuation.
Same for val and test.
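
For example (made-up lines), line i of each file could look like:

train.source: so how was the game yesterday i missed it
train.target: So, how was the game yesterday? I missed it.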

Then, depending on the language of the docs, run:

export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
    --learning_rate=3e-5 \
    --fp16 \
    --gpus 1 \
    --do_train \
    --do_predict \
    --n_val 1000 \
    --val_check_interval 0.1 \
    --data_dir $DATA_DIR \
    --model_name_or_path bart-base

From scratch

You can also use finetune.py to train from scratch by calling, for example,

from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig()  # ...whatever settings you want...
model = BartForConditionalGeneration(config)  # random weights (note: not from_pretrained)
model.save_pretrained('rand_bart')

But I would not do that in your position.

(If the docs are not in English, you could try starting from mbart.)

From your code

If you want to do it from scratch as a learning exercise, make sure your decoder has a causal mask so it can’t attend to the next word.
I don’t totally understand the tensor you printed, but it looks concerningly redundant.
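
For example (rough sketch; depending on your transformers version, from_encoder_decoder_configs may already set these flags for you):

import transformers

encoder_config = transformers.BertConfig()
decoder_config = transformers.BertConfig()
decoder_config.is_decoder = True           # causal (left-to-right) self-attention mask
decoder_config.add_cross_attention = True  # let the decoder attend to the encoder outputs

config = transformers.EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config
)
model = transformers.EncoderDecoderModel(config)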

Let me know if that helps, sorry for the slow response.

So, that tensor shows the tokens predicted for each sentence.

In that example, it started with 101, which is the start-of-sentence token, and then it kept repeating 1996, which is “the”.
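
(You can double-check the ids with the tokenizer:)

import transformers

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.convert_ids_to_tokens([101, 1996]))  # ['[CLS]', 'the']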

I’ve tried the same approach using Bart, but with no success.

Here’s an example:

import torch
import transformers
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = transformers.BartForConditionalGeneration.from_pretrained('facebook/bart-large').to(device)
tokenizer = transformers.BartTokenizer.from_pretrained('facebook/bart-large')

# `dataset` is a translation dataset loaded elsewhere; I only take one batch of 8 sentences
dl = DataLoader(dataset['train'], batch_size=8)

optimizer = torch.optim.AdamW(model.parameters())

epochs = 11
max_len = 50

data = next(iter(dl))

for e in range(epochs):

    en_tokenized_data = tokenizer.batch_encode_plus(
        data['translation']['en'],
        max_length=max_len,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt'
    )

    en_tokenized_data = en_tokenized_data.to(device)

    model.train()
    optimizer.zero_grad()

    # Encoder input, decoder input and labels are all the same ids here,
    # since I just want the model to reconstruct its input
    result = model(
        input_ids=en_tokenized_data['input_ids'],
        attention_mask=en_tokenized_data['attention_mask'],
        decoder_input_ids=en_tokenized_data['input_ids'],
        decoder_attention_mask=en_tokenized_data['attention_mask'],
        labels=en_tokenized_data['input_ids']
    )

    loss, output = result[:2]
    print(f"Epoch: {e}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()
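
And this is roughly how I decode the predictions shown below (sketch, reusing the model, tokenizer and batch from the snippet above):

model.eval()
with torch.no_grad():
    generated = model.generate(
        en_tokenized_data['input_ids'],
        attention_mask=en_tokenized_data['attention_mask'],
        max_length=max_len,
    )
print(tokenizer.batch_decode(generated)[0])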

Before starting the training phase, it seems the model knows how to rebuild the original sentence:

Examples: 
 <s>ResResumption of the session</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Original:
'Resumption of the session'

However, after 10 epochs, it starts generating garbage:

Examples: 
 <s> the the the the the the the the the the the the the the the the the the the

I think the issue lies in how the causal mask is used. I am not sure how to pass it to the model, since it’s not an argument of the forward pass. Reading the docs, I thought it was generated automatically by the model.

Or maybe I am trying to use the incorrect model for this task. I am pretty confused tbh.

Am I doing the training process correctly? In this test I just wanted to see whether the model could learn to reconstruct 8 sentences from my dataset.

Thanks!

Sorry to bother you again, and I hope this is the last time I’ll need to, but could you take a look at my previous post? @sshleifer

I’ve recently checked the BERT paper and I think I understand why my training isn’t working… at least I think so. BERT is trained on both next sentence prediction and masked language modeling, so maybe my inputs and outputs are not what the model expects. I still have to take a look at Bart to understand it better, though.

I thought using the tokenizer was enough to create the inputs/outputs expected by the model.

Thanks again! :hugs:

Hi, I am currently facing a similar problem. Did you find a solution? @mikaelsouza