T5 Pointer Generator

Hi everyone, I'm working on a text summarization task. The model I use is mt0-base with the MT5ForConditionalGeneration class, and I combined this class with a pointer network. Here is my code:

import torch
from torch import nn
from torch.nn import NLLLoss
from transformers import MT5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

class MT5PointerGenerator(MT5ForConditionalGeneration):

_keys_to_ignore_on_load_missing = ["linear_copy.weight", "linear_copy.bias"]

def __init__(self, config):
    super().__init__(config)
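    # copy gate: projects each decoder hidden state to a scalar p_copy in (0, 1)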
    self.linear_copy = nn.Linear(self.model_dim, 1)
    
def forward(self,
            input_ids = None,
            attention_mask = None,
            decoder_input_ids = None,
            decoder_attention_mask = None,
            head_mask = None,
            decoder_head_mask = None,
            cross_attn_head_mask = None,
            encoder_outputs = None,
            past_key_values = None,
            inputs_embeds = None,
            decoder_inputs_embeds = None,
            labels = None,
            use_cache = None,
            output_attentions = None,
            output_hidden_states = None,
            return_dict = None):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
    if head_mask is not None and decoder_head_mask is None:
        if self.config.num_layers == self.config.num_decoder_layers:
            decoder_head_mask = head_mask
    
    if encoder_outputs is None:
        encoder_outputs = self.encoder(input_ids = input_ids,
                                       attention_mask = attention_mask,
                                       inputs_embeds = inputs_embeds,
                                       head_mask = head_mask,
                                       output_attentions = output_attentions,
                                       output_hidden_states = output_hidden_states,
                                       return_dict = return_dict)
        
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(last_hidden_state = encoder_outputs[0],
                                          hidden_states = encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                                          attentions = encoder_outputs[2] if len(encoder_outputs) > 2 else None)
    
    hidden_states = encoder_outputs[0]
    
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        
    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
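        # teacher forcing: build decoder_input_ids by shifting the labels one position to the right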
        decoder_input_ids = self._shift_right(labels)
    
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
    
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
    
    decoder_outputs = self.decoder(input_ids = decoder_input_ids,
                                   attention_mask = decoder_attention_mask,
                                   inputs_embeds = decoder_inputs_embeds,
                                   past_key_values = past_key_values,
                                   encoder_hidden_states = hidden_states,
                                   encoder_attention_mask = attention_mask,
                                   head_mask = decoder_head_mask,
                                   cross_attn_head_mask = cross_attn_head_mask,
                                   use_cache = use_cache,
                                   output_attentions = output_attentions,
                                   output_hidden_states = output_hidden_states,
                                   return_dict = return_dict)
    sequence_output = decoder_outputs[0]
    
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)
    
    if self.config.tie_word_embeddings:
        sequence_output = sequence_output * (self.model_dim ** -0.5)
    
    lm_logits = self.lm_head(sequence_output)
    
    # Copy distribution
    cross_attentions = decoder_outputs["cross_attentions"][-1]
    cross_attentions = torch.mean(cross_attentions, dim = 1)
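    # -> (batch, target_length, encoder_length): head-averaged cross-attention of the last
    # decoder layer, used as the copy distribution over source positions (needs output_attentions=True)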
    
    # Probability of copying
    p_copy = torch.sigmoid(self.linear_copy(sequence_output))
    
    # Merge distribution
    original_word_pro = torch.softmax(lm_logits, dim = -1) * (1 - p_copy) #[batch, sequence_length, vocab_size]
    copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1) #(batch, target_length, encoder_length)
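    # scatter_add adds the p_copy-weighted attention mass onto the vocabulary ids of the
    # source tokens, producing the final mixed distribution (generation + copying)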
    lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions*p_copy)
    
    eps = 1e-7
    lm_logits = torch.log(lm_logits + eps)
    
    loss = None
    if labels is not None:
        loss_fct = NLLLoss(ignore_index = -100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        
    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output
    
    return Seq2SeqLMOutput(loss = loss,
                           logits = lm_logits,
                           past_key_values = decoder_outputs.past_key_values,
                           decoder_hidden_states = decoder_outputs.hidden_states,
                           decoder_attentions = decoder_outputs.attentions,
                           cross_attentions = decoder_outputs.cross_attentions,
                           encoder_last_hidden_state = encoder_outputs.last_hidden_state,
                           encoder_hidden_states = encoder_outputs.hidden_states,
                           encoder_attentions = encoder_outputs.attentions)

model = MT5PointerGenerator(…)
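For reference, a training-style forward pass with this class looks roughly like this (a minimal sketch; the checkpoint name, tokenizer, and dummy texts are just placeholders, not my real pipeline):

from transformers import AutoTokenizer

# minimal sketch, assuming the "bigscience/mt0-base" checkpoint and dummy texts
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
model = MT5PointerGenerator.from_pretrained("bigscience/mt0-base")

enc = tokenizer(["summarize: a long source document ..."], return_tensors="pt")
labels = tokenizer(["a short summary"], return_tensors="pt").input_ids

outputs = model(input_ids=enc.input_ids,
                attention_mask=enc.attention_mask,
                labels=labels,
                output_attentions=True)  # the copy head reads the decoder cross-attentions
print(outputs.loss)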
When I train the model and compute the loss there is no problem; however, when I call the model.generate(…) method I get this error:

in :2

    1 for batch in valid_dataloader:
❱   2     outputs = model.generate(input_ids = batch["input_ids"],
    3                              attention_mask = batch["attention_mask"],
    4                              output_attentions = True,
    5                              num_beams = 3,

/opt/conda/lib/python3.10/site-packages/peft/peft_model.py:952 in generate

    949         )
    950         try:
    951             if not isinstance(peft_config, PromptLearningConfig):
❱   952                 outputs = self.base_model.generate(**kwargs)
    953             else:
    954                 if "input_ids" not in kwargs:
    955                     raise ValueError("input_ids must be provided for Peft model generati

/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in decorate_context

    112     @functools.wraps(func)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
❱   115             return func(*args, **kwargs)
    116
    117     return decorate_context
    118

/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1490 in generate

   1487                 **model_kwargs,
   1488             )
   1489             # 13. run beam search
❱  1490             return self.beam_search(
   1491                 input_ids,
   1492                 beam_scorer,
   1493                 logits_processor=logits_processor,

/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2749 in beam_search

   2746
   2747             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2748
❱  2749             outputs = self(
   2750                 **model_inputs,
   2751                 return_dict=True,
   2752                 output_attentions=output_attentions,

/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl

   1498         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
❱  1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []
   1504         backward_pre_hooks = []

in forward:105

    102
    103         # Merge distribution
    104         original_word_pro = torch.softmax(lm_logits, dim = -1) * (1 - p_copy) #[batch, s
❱   105         copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1) #(bat
    106         lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions
    107
    108         eps = 1e-7

AttributeError: 'NoneType' object has no attribute 'unsqueeze'

Please help me, I can't fix this bug.


During generation, forward is called with whatever prepare_inputs_for_generation returns, and for encoder-decoder models that method does not put input_ids in the returned dict, so your forward receives input_ids=None. Override it as follows so input_ids is passed through:

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        res = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            encoder_outputs=encoder_outputs,
            **kwargs
        )
        # forward() needs input_ids to build the copy distribution
        res['input_ids'] = input_ids
        return res
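With this override in place, the generation loop from your traceback can stay the same, roughly (a sketch; valid_dataloader and the generation arguments come from your snippet, max_new_tokens is just an assumed example):

model.eval()
with torch.no_grad():
    for batch in valid_dataloader:
        outputs = model.generate(input_ids=batch["input_ids"],
                                 attention_mask=batch["attention_mask"],
                                 output_attentions=True,  # the copy head needs the cross-attentions
                                 num_beams=3,
                                 max_new_tokens=64)       # assumed; not part of the original call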

This solution doesn't work for me. It raises another error during inference:

466 def forward(
467     self,
468     hidden_states,

(…)
474 output_attentions=False,
475 ):
476 normed_hidden_states = self.layer_norm(hidden_states)
→ 477 attention_output = self.SelfAttention(
478 normed_hidden_states,
479 mask=attention_mask,
480 position_bias=position_bias,
481 layer_head_mask=layer_head_mask,
482 past_key_value=past_key_value,
483 use_cache=use_cache,
484 output_attentions=output_attentions,
485 )
486 hidden_states = hidden_states + self.dropout(attention_output[0])
487 outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them

File /apps/Arch/software/PyTorch/2.1.2-foss-2023a-CUDA-12.1.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
→ 1518 return self._call_impl(*args, **kwargs)

File /apps/Arch/software/PyTorch/2.1.2-foss-2023a-CUDA-12.1.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /mimer/NOBACKUP/groups/cik_data/FrancescoPeriti/vatpub-application/venv/lib/python3.11/site-packages/transformers/models/mt5/modeling_mt5.py:435, in MT5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions)
432 else:
433 position_bias_masked = position_bias
→ 435 scores += position_bias_masked
436 attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
437 scores
438 ) # (batch_size, n_heads, seq_length, key_length)
439 attn_weights = nn.functional.dropout(
440 attn_weights, p=self.dropout, training=self.training
441 ) # (batch_size, n_heads, seq_length, key_length)

RuntimeError: output with shape [5, 6, 1, 1] doesn't match the broadcast shape [5, 6, 1, 512]
