Hello,
I am trying to incorporate a knowledge-distillation loss into the Seq2SeqTrainer. The training script I use is similar to the run_summarization script. It works on CPU and on a single GPU, but it freezes when I try to run on multiple GPUs (it gets stuck at the first batch). Even when I set use_kd_loss
to False (so the loss is computed by the super() call only), it still does not work on multiple GPUs. Below is the trainer that I am using; any help would be greatly appreciated!
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Seq2SeqTrainer
# shift_tokens_right lives in the model-specific modeling file; adjust the import to the model family being distilled
from transformers.models.bart.modeling_bart import shift_tokens_right


class Seq2SeqKDTrainer(Seq2SeqTrainer):
    """
    Seq2SeqTrainer with an optional knowledge-distillation loss: the usual
    cross-entropy data loss, plus a KL loss on the logits and an MSE loss on
    the hidden states of a teacher model, weighted by alpha_data, alpha_logits
    and alpha_hidden.
    """

    def __init__(self,
                 *args,
                 use_kd_loss=False,
                 teacher_model=None,
                 temperature=2.0,
                 normalize_hidden=False,
                 alpha_data=1.0,
                 alpha_logits=0.0,
                 alpha_hidden=0.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.use_kd_loss = use_kd_loss
        self.teacher_model = teacher_model
        # Keep the configurations around to compare encoder/decoder sizes
        self.student_config_dict = self.model.config.to_diff_dict()
        self.teacher_config_dict = self.teacher_model.config.to_diff_dict()
        self.temperature = temperature
        self.normalize_hidden = normalize_hidden
        self.alpha_data = alpha_data
        self.alpha_logits = alpha_logits
        self.alpha_hidden = alpha_hidden

    def compute_loss(self, model, inputs, return_outputs=False):
        # Ask the student for hidden states (only needed for KD) and a dict-style output
        inputs["output_hidden_states"] = self.use_kd_loss
        inputs["return_dict"] = True

        # Cross-entropy data loss, identical to the default Seq2SeqTrainer loss
        data_loss, student_outputs = super().compute_loss(model, inputs, return_outputs=True)

        # Initialize the KD component losses to 0 and only update them if KD is enabled
        enc_hidden_loss, dec_hidden_loss, logits_loss = 0.0, 0.0, 0.0
        if self.use_kd_loss:
            input_ids, source_mask, labels = inputs["input_ids"], inputs["attention_mask"], inputs["labels"]
            pad_token_id = self.tokenizer.pad_token_id
            decoder_input_ids = shift_tokens_right(
                input_ids=labels,
                pad_token_id=pad_token_id,
                decoder_start_token_id=self.teacher_model.config.decoder_start_token_id,
            )

            # Teacher forward pass; the teacher is frozen, so no gradients are needed
            teacher_model = self.teacher_model.to(input_ids.device)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids,
                                                attention_mask=source_mask,
                                                decoder_input_ids=decoder_input_ids,
                                                output_hidden_states=True,
                                                return_dict=True,
                                                use_cache=False)

            # Logits (soft-target) loss over the non-padding decoder positions
            decoder_mask = decoder_input_ids.ne(pad_token_id)
            logits_loss = self._compute_logits_loss(student_logits=student_outputs.logits,
                                                    teacher_logits=teacher_outputs.logits,
                                                    mask=decoder_mask,
                                                    temperature=self.temperature)

            # Only compute the encoder hidden loss if the student's encoder is smaller
            if self.student_config_dict["encoder_layers"] < self.teacher_config_dict["encoder_layers"]:
                enc_hidden_loss = self._compute_hidden_loss(
                    student_hidden_states=student_outputs.encoder_hidden_states,
                    teacher_hidden_states=teacher_outputs.encoder_hidden_states,
                    attention_mask=source_mask,
                    teacher_layer_indices=self.student_config_dict["encoder_layer_indices"],
                    normalize=self.normalize_hidden,
                )

            # Only compute the decoder hidden loss if the student's decoder is smaller
            if self.student_config_dict["decoder_layers"] < self.teacher_config_dict["decoder_layers"]:
                dec_hidden_loss = self._compute_hidden_loss(
                    student_hidden_states=student_outputs.decoder_hidden_states,
                    teacher_hidden_states=teacher_outputs.decoder_hidden_states,
                    attention_mask=decoder_mask,
                    teacher_layer_indices=self.student_config_dict["decoder_layer_indices"],
                    normalize=self.normalize_hidden,
                )

        total_loss = self.alpha_data * data_loss + \
                     self.alpha_logits * logits_loss + \
                     self.alpha_hidden * (enc_hidden_loss + dec_hidden_loss)
        return (total_loss, student_outputs) if return_outputs else total_loss

    @staticmethod
    def _compute_logits_loss(student_logits: torch.Tensor,
                             teacher_logits: torch.Tensor,
                             mask: torch.Tensor,
                             temperature: float = 2.0):
        sel_mask = mask[:, :, None].expand_as(student_logits)
        vocab_size = student_logits.size(-1)
        # Keep only the logits at non-padding positions
        student_logits_select = torch.masked_select(student_logits, sel_mask).view(-1, vocab_size)
        teacher_logits_select = torch.masked_select(teacher_logits, sel_mask).view(-1, vocab_size)
        assert (
            student_logits_select.shape == teacher_logits_select.shape
        ), "Expected tensors of the same size. Got student: {}, teacher: {}".format(
            student_logits_select.shape, teacher_logits_select.shape)
        # KL divergence between the temperature-scaled distributions
        # (KLDivLoss expects log-probabilities as input and probabilities as target)
        logits_loss_fct = nn.KLDivLoss(reduction="batchmean")
        logits_loss = (
            logits_loss_fct(
                F.log_softmax(student_logits_select / temperature, dim=-1),
                F.softmax(teacher_logits_select / temperature, dim=-1),
            ) * temperature ** 2
        )
        return logits_loss

    @staticmethod
    def _compute_hidden_loss(student_hidden_states: Tuple[torch.Tensor],
                             teacher_hidden_states: Tuple[torch.Tensor],
                             attention_mask: torch.Tensor,
                             teacher_layer_indices: list,
                             normalize: bool = False):
        mask = attention_mask.to(student_hidden_states[0])  # Type and/or device conversion
        valid_count = mask.sum() * student_hidden_states[0].size(-1)  # Number of unmasked elements
        # Stack the hidden states, skipping the first student state (the embedding output)
        # and taking the teacher layers given by teacher_layer_indices
        student_hidden_stack = torch.stack([state for state in student_hidden_states[1:]])
        teacher_hidden_stack = torch.stack([teacher_hidden_states[i] for i in teacher_layer_indices])
        assert (
            student_hidden_stack.shape == teacher_hidden_stack.shape
        ), "Expected tensors of the same size. Got student: {}, teacher: {}".format(
            student_hidden_stack.shape, teacher_hidden_stack.shape)
        # Normalize if specified
        if normalize:
            student_hidden_stack = F.layer_norm(student_hidden_stack, student_hidden_stack.shape[1:])
            teacher_hidden_stack = F.layer_norm(teacher_hidden_stack, teacher_hidden_stack.shape[1:])
        # Masked MSE over the stacked hidden states, averaged over the valid positions
        loss_fct = nn.MSELoss(reduction="none")
        mse_loss = loss_fct(student_hidden_stack, teacher_hidden_stack)
        masked_mse_loss = (mse_loss * mask.unsqueeze(dim=0).unsqueeze(dim=-1)).sum() / valid_count
        return masked_mse_loss
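
In case it is relevant, this is roughly how I build and run the trainer. The checkpoints, the tiny dataset, and the training arguments below are placeholders rather than my exact setup, and my real student config additionally defines the encoder_layer_indices / decoder_layer_indices entries that the hidden loss reads:

from datasets import Dataset
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainingArguments)

# Placeholder student/teacher checkpoints
# (my actual student config also contains encoder_layer_indices / decoder_layer_indices)
student = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-6-6")
teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

# Placeholder dataset, tokenized the same way as in run_summarization
raw_dataset = Dataset.from_dict({"text": ["a short document"], "summary": ["a short summary"]})

def preprocess(example):
    model_inputs = tokenizer(example["text"], truncation=True, max_length=512)
    model_inputs["labels"] = tokenizer(example["summary"], truncation=True, max_length=64)["input_ids"]
    return model_inputs

train_dataset = raw_dataset.map(preprocess, remove_columns=["text", "summary"])

training_args = Seq2SeqTrainingArguments(
    output_dir="./kd_output",
    per_device_train_batch_size=4,
    num_train_epochs=1,
)

trainer = Seq2SeqKDTrainer(
    model=student,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=student),
    tokenizer=tokenizer,
    use_kd_loss=True,
    teacher_model=teacher,
    temperature=2.0,
    alpha_data=1.0,
    alpha_logits=0.8,
    alpha_hidden=3.0,
)
trainer.train()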