Hi everyone!
As part of our research work, we are attempting to reproduce distilgpt2. We downloaded OpenWebText, binarized it as indicated here, extracted the student's weights (see the sketch after the command below), and used (almost) the same configuration as in research_projects/distillation:
python -m torch.distributed.launch \
    --nproc_per_node=$N_GPU_NODE \
    --nnodes=$N_NODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    train.py \
        --fp16 \
        --force \
        --gpus $WORLD_SIZE \
        --student_type gpt2 \
        --student_config training_configs/distilgpt2.json \
        --student_pretrained_weights ./student/pytorch_model.bin \
        --teacher_type gpt2 \
        --teacher_name gpt2 \
        --alpha_ce 5.0 --alpha_cos 1.0 --alpha_clm 0.5 \
        --freeze_pos_embs \
        --dump_path my_dir \
        --data_file data/owt.pickle \
        --token_counts data/token_owt.pickle
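For completeness, this is roughly how we produced ./student/pytorch_model.bin. It is only a minimal sketch of initializing the 6-layer student from the 12-layer gpt2 teacher; the layer selection shown ([0, 2, 4, 6, 8, 10]) is an assumption on our side and may not match what the repo's extraction script actually picks.

```python
# Minimal sketch: initialize a 6-layer GPT-2 student from the 12-layer gpt2 teacher.
# The set of teacher layers copied here is an assumption, not necessarily the
# mapping used for the official distilgpt2 checkpoint.
import os
import torch
from transformers import GPT2LMHeadModel

teacher = GPT2LMHeadModel.from_pretrained("gpt2")
state_dict = teacher.state_dict()

student_state_dict = {}

# Copy token/position embeddings, final layer norm and the (tied) LM head unchanged.
for name, param in state_dict.items():
    if name.startswith(("transformer.wte", "transformer.wpe", "transformer.ln_f", "lm_head")):
        student_state_dict[name] = param

# Map every other teacher block onto the 6 student blocks (assumed layer selection).
teacher_layers = [0, 2, 4, 6, 8, 10]
for student_idx, teacher_idx in enumerate(teacher_layers):
    old_prefix = f"transformer.h.{teacher_idx}."
    new_prefix = f"transformer.h.{student_idx}."
    for name, param in state_dict.items():
        if name.startswith(old_prefix):
            student_state_dict[name.replace(old_prefix, new_prefix)] = param

os.makedirs("student", exist_ok=True)
torch.save(student_state_dict, "./student/pytorch_model.bin")
```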
We kept the default values for the rest of the hyper-parameters. However, the model is not converging: perplexity is over 80 on the WikiText-103 test set (measured as in the sketch below). Can anyone confirm whether the settings above are correct? Thanks a lot!
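For reference, this is how we measure the perplexity quoted above; a minimal sketch using non-overlapping 1024-token windows on the WikiText-103 test split. It assumes the trained checkpoint can be loaded as a standard GPT2LMHeadModel directory (my_dir is the --dump_path from the command), so the exact loading step may differ from what the training script dumps.

```python
# Minimal sketch of the perplexity evaluation on WikiText-103 (test split),
# using non-overlapping windows of the model's maximum context length.
# Assumes the distilled checkpoint loads as a regular GPT2LMHeadModel; the
# directory name "my_dir" is the --dump_path used above.
import math
import torch
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("my_dir").to(device).eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

test = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

max_len = model.config.n_positions  # 1024 for GPT-2-style models
nlls, n_predicted = [], 0
with torch.no_grad():
    for start in range(0, encodings.input_ids.size(1), max_len):
        input_ids = encodings.input_ids[:, start:start + max_len].to(device)
        if input_ids.size(1) < 2:
            break
        loss = model(input_ids, labels=input_ids).loss  # mean NLL over predicted tokens
        nlls.append(loss * (input_ids.size(1) - 1))
        n_predicted += input_ids.size(1) - 1

print(f"perplexity: {math.exp(torch.stack(nlls).sum().item() / n_predicted):.2f}")
```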