Hello all,
I am trying to distill wav2vec2-base-960h on the TIMIT dataset by having the student model's outputs match the teacher model's, using a KL-divergence loss. For this I wrote a CustomTrainer that overrides the compute_loss
function:
import torch
from torch.nn.functional import kl_div, log_softmax, softmax

from transformers import Trainer
class CustomTrainer(Trainer):
    """Trainer that distills a student wav2vec2 model against a frozen teacher.

    The training objective is the KL divergence between the student's and the
    teacher's output distributions over the vocabulary axis (response-based
    distillation).

    NOTE(review): ``compute_loss`` reads a module-level ``teacher_model`` —
    it must be defined, in eval mode, and on the same device as the inputs.
    """

    def __init__(self, model, args=None, data_collator=None, train_dataset=None,
                 eval_dataset=None, tokenizer=None, model_init=None,
                 compute_metrics=None, callbacks=None, optimizers=(None, None)):
        # Forward everything by keyword: Trainer.__init__'s positional order
        # has changed between transformers versions, and passing positionally
        # silently misassigns arguments when it does.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )

    def compute_loss(self, model_obj, inputs, return_outputs=False):
        """Return the distillation loss KL(teacher || student-approximation).

        Bug fixed: ``torch.nn.functional.kl_div`` requires its *input* to be
        log-probabilities and its *target* to be probabilities.  The original
        code passed raw logits for both, which computes a meaningless (and
        possibly negative) quantity — apply log_softmax / softmax first.
        """
        outputs_student = model_obj(**inputs)
        # The teacher is frozen: no gradients, saves memory and compute.
        with torch.no_grad():
            outputs_teacher = teacher_model(**inputs)

        student_logits = outputs_student.get("logits")
        teacher_logits = outputs_teacher.get("logits")

        # Normalise over the class (vocabulary) dimension before KL.
        kl_div_loss = kl_div(
            log_softmax(student_logits, dim=-1),  # input: log-probabilities
            softmax(teacher_logits, dim=-1),      # target: probabilities
            reduction='batchmean',
        )
        # Bug fixed: Trainer.prediction_step expects (loss, model outputs)
        # when return_outputs=True; returning the bare logits tensor breaks
        # evaluation (outputs are indexed/unpacked downstream).
        return (kl_div_loss, outputs_student) if return_outputs else kl_div_loss
Now the issue is that whenever I pass a per_device_eval_batch_size
greater than 2, I get the following error, which does not occur when the eval batch size is 2 or less:
# Hyper-parameters for distillation fine-tuning on TIMIT.
training_args = TrainingArguments(
    output_dir=repo_name,
    disable_tqdm=False,
    group_by_length=True,  # batch similar-length audio together to cut padding
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    eval_steps=200,
    save_steps=200,
    logging_steps=200,
    save_total_limit=2,  # keep only the two most recent checkpoints
    num_train_epochs=30,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    fp16=True,
    gradient_checkpointing=True,  # trade compute for activation memory
    # push_to_hub=True
)
# Wire the distillation trainer together and launch training.
trainer = CustomTrainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
)
trainer.train()
Any help as to why this occurs? Thanks in advance…