T5 Pointer Generator

Hi everyone, I'm working on a text summarization task. The model I use is mt0-base via the MT5ForConditionalGeneration class, which I combined with a pointer network. Here is my code:

import torch
from torch import nn
from torch.nn import NLLLoss
from transformers import MT5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput


class MT5PointerGenerator(MT5ForConditionalGeneration):

    _keys_to_ignore_on_load_missing = ["linear_copy.weight", "linear_copy.bias"]

    def __init__(self, config):
        super().__init__(config)
        # Projects each decoder hidden state to the scalar copy probability p_copy
        self.linear_copy = nn.Linear(self.model_dim, 1)

    def forward(self,
                input_ids = None,
                attention_mask = None,
                decoder_input_ids = None,
                decoder_attention_mask = None,
                head_mask = None,
                decoder_head_mask = None,
                cross_attn_head_mask = None,
                encoder_outputs = None,
                past_key_values = None,
                inputs_embeds = None,
                decoder_inputs_embeds = None,
                labels = None,
                use_cache = None,
                output_attentions = None,
                output_hidden_states = None,
                return_dict = None):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                decoder_head_mask = head_mask

        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids = input_ids,
                                           attention_mask = attention_mask,
                                           inputs_embeds = inputs_embeds,
                                           head_mask = head_mask,
                                           output_attentions = output_attentions,
                                           output_hidden_states = output_hidden_states,
                                           return_dict = return_dict)
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(last_hidden_state = encoder_outputs[0],
                                              hidden_states = encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                                              attentions = encoder_outputs[2] if len(encoder_outputs) > 2 else None)

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(labels)

        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        decoder_outputs = self.decoder(input_ids = decoder_input_ids,
                                       attention_mask = decoder_attention_mask,
                                       inputs_embeds = decoder_inputs_embeds,
                                       past_key_values = past_key_values,
                                       encoder_hidden_states = hidden_states,
                                       encoder_attention_mask = attention_mask,
                                       head_mask = decoder_head_mask,
                                       cross_attn_head_mask = cross_attn_head_mask,
                                       use_cache = use_cache,
                                       output_attentions = output_attentions,
                                       output_hidden_states = output_hidden_states,
                                       return_dict = return_dict)
        sequence_output = decoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        # Copy distribution: cross-attention of the last decoder layer, averaged over heads
        # (requires output_attentions=True so the decoder returns cross_attentions)
        cross_attentions = decoder_outputs["cross_attentions"][-1]
        cross_attentions = torch.mean(cross_attentions, dim = 1)

        # Probability of copying
        p_copy = torch.sigmoid(self.linear_copy(sequence_output))

        # Merge the generation and copy distributions
        original_word_pro = torch.softmax(lm_logits, dim = -1) * (1 - p_copy)  # [batch, target_length, vocab_size]
        copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1)  # [batch, target_length, encoder_length]
        lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions * p_copy)

        eps = 1e-7
        lm_logits = torch.log(lm_logits + eps)

        loss = None
        if labels is not None:
            loss_fct = NLLLoss(ignore_index = -100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(loss = loss,
                               logits = lm_logits,
                               past_key_values = decoder_outputs.past_key_values,
                               decoder_hidden_states = decoder_outputs.hidden_states,
                               decoder_attentions = decoder_outputs.attentions,
                               cross_attentions = decoder_outputs.cross_attentions,
                               encoder_last_hidden_state = encoder_outputs.last_hidden_state,
                               encoder_hidden_states = encoder_outputs.hidden_states,
                               encoder_attentions = encoder_outputs.attentions)
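
To make the merge step concrete, here is a tiny standalone sketch (made-up shapes and values, not part of the model) of what torch.scatter_add computes there: P(w) = (1 - p_copy) * P_vocab(w) + p_copy * (sum of the attention weights of the source positions that hold w).

import torch

# Toy sizes: batch = 1, target_length = 1, vocab_size = 5, encoder_length = 3
p_vocab = torch.tensor([[[0.1, 0.2, 0.3, 0.2, 0.2]]])   # softmax(lm_logits)
attn    = torch.tensor([[[0.5, 0.3, 0.2]]])              # cross-attention averaged over heads
src_ids = torch.tensor([[4, 2, 4]])                      # encoder input_ids
p_copy  = torch.tensor([[[0.4]]])                        # sigmoid(linear_copy(hidden_state))

gen_part = p_vocab * (1 - p_copy)                        # [1, 1, 5]
copy_idx = src_ids.unsqueeze(1)                          # [1, 1, 3]; the repeat is a no-op when target_length = 1
merged   = torch.scatter_add(gen_part, 2, copy_idx, attn * p_copy)
print(merged)  # token 2 gains 0.4*0.3, token 4 gains 0.4*(0.5+0.2); the row still sums to 1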

model = MT5PointerGenerator(…)
When I train the model and compute the loss there is no problem. However, when I call the model.generate(…) method I get this error:

in :2 │
│ │
│ 1 for batch in valid_dataloader: │
│ ❱ 2 │ outputs = model.generate(input_ids = batch["input_ids"], │
│ 3 │ │ │ │ │ │ │ attention_mask = batch["attention_mask"], │
│ 4 │ │ │ │ │ │ │ output_attentions = True, │
│ 5 │ │ │ │ │ │ │ num_beams = 3, │
│ │
│ /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:952 in generate │
│ │
│ 949 │ │ ) │
│ 950 │ │ try: │
│ 951 │ │ │ if not isinstance(peft_config, PromptLearningConfig): │
│ ❱ 952 │ │ │ │ outputs = self.base_model.generate(**kwargs) │
│ 953 │ │ │ else: │
│ 954 │ │ │ │ if "input_ids" not in kwargs: │
│ 955 │ │ │ │ │ raise ValueError("input_ids must be provided for Peft model generati │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1490 in generate │
│ │
│ 1487 │ │ │ │ **model_kwargs, │
│ 1488 │ │ │ ) │
│ 1489 │ │ │ # 13. run beam search │
│ ❱ 1490 │ │ │ return self.beam_search( │
│ 1491 │ │ │ │ input_ids, │
│ 1492 │ │ │ │ beam_scorer, │
│ 1493 │ │ │ │ logits_processor=logits_processor, │
│ │
│ /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2749 in beam_search │
│ │
│ 2746 │ │ │ │
│ 2747 │ │ │ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) │
│ 2748 │ │ │ │
│ ❱ 2749 │ │ │ outputs = self( │
│ 2750 │ │ │ │ **model_inputs, │
│ 2751 │ │ │ │ return_dict=True, │
│ 2752 │ │ │ │ output_attentions=output_attentions, │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ in forward:105 │
│ │
│ 102 │ │ │
│ 103 │ │ # Merge distribution │
│ 104 │ │ original_word_pro = torch.softmax(lm_logits, dim = -1) * (1 - p_copy) #[batch, s │
│ ❱ 105 │ │ copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1) #(bat │
│ 106 │ │ lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions │
│ 107 │ │ │
│ 108 │ │ eps = 1e-7 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'NoneType' object has no attribute 'unsqueeze'

Please help me, I can’t fix this bug

generate() builds the arguments for forward() through prepare_inputs_for_generation, and the implementation MT5 inherits from T5 does not pass input_ids along (it returns decoder_input_ids, encoder_outputs, past_key_values, the attention mask, the head masks and use_cache, but no input_ids), so input_ids is None inside your forward during generation. Override the method so the ids are passed through:

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        res = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            encoder_outputs=encoder_outputs,
            **kwargs,
        )
        # forward() needs input_ids for the copy distribution, so pass them through explicitly
        res['input_ids'] = input_ids
        return res
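
With that override in place, forward() receives input_ids during beam search, so the call from your traceback should go through. Roughly (tokenizer here is an assumption, the matching mt0 tokenizer from your setup):

for batch in valid_dataloader:
    outputs = model.generate(input_ids = batch["input_ids"],
                             attention_mask = batch["attention_mask"],
                             output_attentions = True,   # needed so the decoder returns cross_attentions for the copy step
                             num_beams = 3)
    # decode the generated summaries
    summaries = tokenizer.batch_decode(outputs, skip_special_tokens = True)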