Training DistilGPT2

Hello!
I am trying to find resources/code samples for retraining the DistilGPT2 model on text I have preprocessed myself, but could not find any. Most of the documentation relates to DistilBERT and its uses.
I have also trained a gpt2-simple (TensorFlow-based) model. If there is a way to distill that model as well, it would help me too!
Thanks for your help.


Hi @abhilashpal, you can find the distillation code here. The same script that produces DistilBERT can be used for GPT-2, though it's not documented.

You should be able to use this command after processing your dataset:

# teacher_name: gpt2, or your own teacher model
# data_file: your data path
# token_counts: your own pickle file path
# force: overwrites the `dump_path` if it already exists
python train.py \
    --student_type gpt2 \
    --student_config training_configs/distilgpt2.json \
    --teacher_type gpt2 \
    --teacher_name gpt2 \
    --alpha_ce 5.0 --alpha_cos 1.0 --alpha_clm 0.5 \
    --freeze_pos_embs \
    --dump_path serialization_dir/my_first_training \
    --data_file data/binarized_text.bert-base-uncased.pickle \
    --token_counts data/token_counts.bert-base-uncased.pickle \
    --force
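Since both student and teacher are GPT-2 here, the dataset would presumably be binarized with the GPT-2 tokenizer rather than BERT's. A sketch, assuming `scripts/binarized_data.py` from the distillation examples accepts `gpt2` as a tokenizer type and that `data/my_text.txt` stands in for your own preprocessed file:

python scripts/binarized_data.py \
    --file_path data/my_text.txt \
    --tokenizer_type gpt2 \
    --tokenizer_name gpt2 \
    --dump_file data/binarized_text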

pinging @julien-c for more info


Thanks for the reply. I ran into a problem while running the DistilBERT binarization. Can anyone tell me whether this warning means a line of my data exceeds the maximum sequence length the DistilBERT model expects?

!python scripts/binarized_data.py \
--file_path data/dataemail.txt \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file data/binarized_text

07/19/2020 07:25:56 - WARNING - transformers.tokenization_utils_base - Token indices sequence length is longer than the specified maximum sequence length for this model (626 > 512). Running this sequence through the model will result in indexing errors

@abhilashpal, that's what it looks like to me. I don't have any direct experience with the distillation examples, but I took a quick look at the DistilBERT and BERT papers, and 512 is their maximum token length.
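If the overflow comes from a few long lines, one workaround is to split them into chunks under the limit before binarizing. A minimal sketch, assuming one document per line, and leaving a small margin for any special tokens the binarization script adds (this chunking script is my own, not part of the examples):

from transformers import BertTokenizer

# Leave room below 512 for special tokens added during binarization.
MAX_LEN = 510

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

with open("data/dataemail.txt") as fin, open("data/dataemail_chunked.txt", "w") as fout:
    for line in fin:
        ids = tokenizer.encode(line.strip(), add_special_tokens=False)
        # Emit each run of at most MAX_LEN token IDs as its own line.
        for start in range(0, len(ids), MAX_LEN):
            fout.write(tokenizer.decode(ids[start:start + MAX_LEN]) + "\n")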

I ran the previous command mentioned in the README and it did output the binarized file. However, another file is missing: token_counts.distilgpt2.pickle. How do I generate that file as well?
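For reference, I am adapting the token counts step from the README, something like the following (the exact paths and the GPT-2 vocab size of 50257 are my guesses):

python scripts/token_counts.py \
    --data_file data/binarized_text.pickle \
    --token_counts_dump data/token_counts.distilgpt2.pickle \
    --vocab_size 50257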

When I run the token counts script, it returns `IndexError: list assignment index out of range`. Does anyone know how to resolve this error?
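A guess at the cause: counting scripts like this typically build a fixed-size list from `--vocab_size` and then assign into it by token ID, so the error appears when the binarized data contains a token ID at or above the vocab size passed in (for example, GPT-2 token IDs counted against BERT's 30522). A minimal sketch reproducing it (the example IDs are made up):

from collections import Counter

# Hypothetical token IDs from a GPT-2 tokenizer (IDs go up to 50256).
token_ids = [15496, 50256, 318]
counter = Counter(token_ids)

vocab_size = 30522  # BERT's vocab size: too small for GPT-2 token IDs
counts = [0] * vocab_size
try:
    for token_id, count in counter.items():
        counts[token_id] = count
except IndexError as e:
    print(e)  # list assignment index out of range

If that is what is happening here, passing a `--vocab_size` that matches the tokenizer used for binarization (50257 for GPT-2, 30522 for bert-base-uncased) should resolve it.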