How to extract the "student" model after distillation?

Hey there!

I am using your distillation script (thanks for sharing it!), and from the checkpoints it dumps, it seems that they contain both the teacher and the student.
Assuming that my observation is correct, how can I dump only the student sub-model?

@sshleifer wondering if you have any thoughts.

Great Q!

The saved best_tfmr directory has only the student saved (in Hugging Face format).
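For what it's worth, loading the student back from that directory is an ordinary from_pretrained call. A minimal sketch, assuming your --output_dir was literally output_dir (adjust the path to your run):

```python
# Sketch -- "output_dir" is a placeholder for your actual --output_dir value;
# best_tfmr only exists after a finished distillation run.
import os

best_tfmr = os.path.join("output_dir", "best_tfmr")
if os.path.isdir(best_tfmr):
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # These load only the student weights/config saved in best_tfmr.
    student = AutoModelForSeq2SeqLM.from_pretrained(best_tfmr)
    tokenizer = AutoTokenizer.from_pretrained(best_tfmr)
```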

There is also a PyTorch Lightning weights-only checkpoint option (save_weights_only) you could pass to ModelCheckpoint here. Be aware that this might break --do_predict/trainer.test; you can work around that by running eval as a second step, roughly:
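If all you have is a full Lightning checkpoint whose state_dict mixes teacher and student parameters, you can also strip the teacher entries by key prefix. The "model.teacher." prefix below is an assumption, not the script's guaranteed naming; print your checkpoint's state_dict keys first to confirm what the teacher parameters are actually called:

```python
# Hedged sketch: the "model.teacher." prefix is an assumption -- inspect
# your own checkpoint's state_dict keys before relying on it.
def drop_teacher_keys(state_dict, teacher_prefix="model.teacher."):
    """Return a copy of state_dict with all teacher parameters removed."""
    return {k: v for k, v in state_dict.items()
            if not k.startswith(teacher_prefix)}

# Toy illustration with fake keys (real keys come from
# torch.load(ckpt_path, map_location="cpu")["state_dict"]):
full = {"model.teacher.encoder.weight": 0, "model.student.encoder.weight": 1}
student_only = drop_teacher_keys(full)
```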

# Define useful aliases
run_distributed_eval () {
	proc=$1   # number of GPUs / processes
	m=$2      # model name or path
	dd=$3     # data directory
	sd=$4     # save directory
	shift 4   # remaining args are forwarded as-is
	python -m torch.distributed.launch --nproc_per_node="$proc" run_distributed_eval.py \
		--model_name "$m" --save_dir "$sd" --data_dir "$dd" "$@"
}
eval_best () {
	proc=$1   # number of GPUs / processes
	m=$2      # distillation output dir (contains best_tfmr)
	dd=$3     # data directory
	shift 3
	run_distributed_eval "$proc" "$m/best_tfmr" "$dd" "$m/" "$@"
}

Finally, run (eval_best needs both the output dir and the data dir; replace output_dir and data_dir with your actual paths):

eval_best 1 output_dir data_dir

(If you have more GPUs, change the first arg.)

Okay, you’re right! Thanks!