Dataloader and bart-large-mnli

Hi, I’m using bart-large-mnli without pipeline, and so using manual pytorch, as explained in the example at the bottom of the HF page.

device = 'cpu'
premise = "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."
candidate_labels = ['mobile', 'website', 'billing', 'account access']

entail_contradiction_logits_matrix = []
entail_logits_matrix = []
output = []
multi_class = True
for label in candidate_labels:
    hypothesis = f'This example is {label}.'
    # run through model pre-trained on MNLI
    x = tokenizer.encode_plus(premise, hypothesis, return_tensors='pt',
                        truncation_strategy='only_first')['input_ids']
                        
    logits = nli_model(x)[0]
    # we throw away "neutral" (dim 1) and take the probability of
    # "entailment" (2) as the probability of the label being true 
    if multi_class:
        entail_contradiction_logits_matrix.append(logits[:,[0,2]])
        probs = logits[:,[0,2]].softmax(dim=1)
        prob_label_is_true = probs[:,1]
        output.append(prob_label_is_true.item())
    else:
        entail_logits_matrix.append(logits[:,[2]])

if multi_class== False:
    entails_tensor = torch.cat(entail_logits_matrix, dim=1)
    probs = entails_tensor.softmax(dim=1)
    output = probs.detach().to('cpu').numpy()

print("premise:", premise)
print("candidate_labels:", candidate_labels)
print(output)


So, everything works fine

##OUTPUT
premise: Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.
candidate_labels: ['mobile', 'website', 'billing', 'account access']
[0.9908875226974487, 0.0016617656219750643, 0.37062010169029236, 0.445667564868927]

Now I want to create a dataloader and classify multiple examples at a time. (I’ve replaced unimportant code with ‘…’)

def generate_data_loader(self, examples):
        '''
        Generate a Dataloader given the input examples

        examples: a list of tuples (text, type_text, lang_text, doc_id, parag_id)
        ''' 

        encoded_text_array = []
        label_type_text_array = []
        label_lang_text_array = []

        max_seq_length = 256 # the maximum length to be considered in input
        batch_size = 64 # the batch size

        # Tokenization 
        for (text, type_text, lang_text, doc_id, parag_id) in examples:

            premise = text
            hypothesis_list = self.generate_hypothesis(...)

            for hypothesis in hypothesis_list:
                encoded_text = self.tokenizer.encode_plus(premise, hypothesis, add_special_tokens=True, truncation_strategy='only_first', return_tensors='pt', max_length=max_seq_length, padding='max_length')['input_ids']

                encoded_text_array.append(encoded_text)
                [...]

        encoded_text_array = torch.stack(encoded_text_array)  
        label_lang_text_array = torch.tensor(..., dtype=torch.long)
        label_type_text_array = torch.tensor(..., dtype=torch.long)

        # Building the TensorDataset
        dataset = TensorDataset(encoded_text_array, label_lang_text_array, 
								label_type_text_array, ...)

        sampler = SequentialSampler

        # Building the DataLoader
        return DataLoader(
                    dataset,  # The training samples.
                    sampler = sampler(dataset), # the adopted sampler
                    batch_size = batch_size) # Trains with this batch size.

and finally

for batch in dataloader: 
            b_encoded_text = batch[0].to(self.device)
            [...]

            b_encoded_text = torch.unbind(b_encoded_text)

			with torch.no_grad():
				self.model.eval()
				logits = self.model(b_encoded_text)[0]

But I get an error.


<ipython-input-36-44a7c2660e53> in model_predictor(self, dataloader)
    208                     # Each batch is classifed
    209                     #logits_subjobj, logits_priv, logits_pos, logits_neg, _ = self.model(b_input_ids, b_input_mask)
--> 210                     logits = self.model(b_encoded_text)[0]
    211                 if multi_class:
    212                     print("logits", logits)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1482             output_attentions=output_attentions,
   1483             output_hidden_states=output_hidden_states,
-> 1484             return_dict=return_dict,
   1485         )
   1486         hidden_states = outputs[0]  # last hidden state

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1185 
   1186             decoder_input_ids = shift_tokens_right(
-> 1187                 input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
   1188             )
   1189 

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py in shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
     64     Shift input ids one token to the right.
     65     """
---> 66     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
     67     shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
     68     shifted_input_ids[:, 0] = decoder_start_token_id

AttributeError: 'tuple' object has no attribute 'new_zeros'

Apart from the specific error, I’m sure I haven’t understood how to pass the data with the dataloader to the model. How is this done? Thank you in advance