BART model fine-tuning gives unexpected, irrelevant results

I am using a pre-trained BART model (facebook/bart-large-cnn) for text generation.

Problem-
Without fine-tuning, the model generates a relevant sentence for the given input, but it does not match my task.
So I tried to fine-tune it using “transformers/examples/pytorch/language-modeling/run_clm.py” (exact command below), and the fine-tuning run completed without errors.
After fine-tuning, I get much worse results from the model, and the output does not match the input sequence.

input text - “Bill has 9 apple and Jim”

Results-
Using the pre-trained model - “Bill has 9 apple and Jim has 8 apple trees. Bill has been married to Jim for more than 30 years. Jim and Bill have 9 children together and 9 grandchildren together.”

After fine-tuning - “mutants mutants ogre mutants ogre tutor mutants tutor fuelled mutants fuelled tapes mutants tapes similarities mutants similarities dolls mutants dolls slowdown mutants slowdown Crazy mutants Crazy accelerate mutants accelerate zombies mutants zombies Cum mutants Cum”

(After fine-tuning, the generated output is not relevant to the input and is meaningless; it does not even contain the given input text.)

cmd = '''
python '/content/gdrive/MyDrive/Colab Notebooks/Text_generation/transformers/examples/pytorch/language-modeling/run_clm.py' \
    --model_name_or_path facebook/bart-large-cnn \
    --train_file {0} \
    --validation_file {1} \
    --do_train \
    --do_eval \
    --num_train_epochs 3 \
    --overwrite_output_dir \
    --per_device_train_batch_size 2 \
    --output_dir {2}
'''.format(file_train, file_eval, weights_dir)
  • file_train & file_eval are plain text files used for the fine-tuning task
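
The command string is then launched from the notebook. A minimal sketch of one way to run it (subprocess is just one option; it assumes file_train, file_eval and weights_dir are already set to the paths used above):

import subprocess

# Launch the run_clm.py fine-tuning command assembled above.
# shell=True lets the shell handle the backslash line continuations in the command string.
subprocess.run(cmd, shell=True, check=True)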

Model

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
text = "Bill has 9 apple and Jim"  # the input shown above

# Load the tokenizer and fine-tuned weights saved by run_clm.py.
tokenizer = BartTokenizer.from_pretrained(weights_dir)
model = BartForConditionalGeneration.from_pretrained(weights_dir).to(torch_device)

# Encode the input and generate with beam search.
input_ids = tokenizer([text], return_tensors='pt', truncation=True, max_length=30)['input_ids'].to(torch_device)
text_ids = model.generate(input_ids,
                          num_beams=4,
                          length_penalty=2.0,
                          max_length=50,
                          min_length=30,
                          no_repeat_ngram_size=2)
generated_txt = tokenizer.decode(text_ids.squeeze(), skip_special_tokens=True)
print(generated_txt)
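
For reference, the “before fine-tuning” output above came from the stock facebook/bart-large-cnn checkpoint. A minimal sketch of that baseline call (assuming the same generation settings were used, which is an assumption):

from transformers import BartTokenizer, BartForConditionalGeneration

# Baseline: same input and generation settings, but with the original pre-trained checkpoint.
base_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
base_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(torch_device)

base_ids = base_tokenizer([text], return_tensors='pt', truncation=True, max_length=30)['input_ids'].to(torch_device)
base_out = base_model.generate(base_ids, num_beams=4, length_penalty=2.0,
                               max_length=50, min_length=30, no_repeat_ngram_size=2)
print(base_tokenizer.decode(base_out.squeeze(), skip_special_tokens=True))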

How can I fine-tune the model for this task so that it produces relevant output?
