How to set up DistilBertModel to use a batch_size?

Goal: I want to get the [CLS] values, but I am getting an error when I call DistilBertModel.
My code:

import transformers as ppb

# model class, tokenizer class and name of the pretrained weights
model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizerFast, 'distilbert-base-uncased')

tokenizer = tokenizer_class.from_pretrained(pretrained_weights, cache_dir=<path>)
model = model_class.from_pretrained(pretrained_weights, from_tf=True, cache_dir=<path>)
# [beginning of EDIT]
def my_encode(tokenizer, texts, max_length=MAX_LENGTH):  # MAX_LENGTH is defined elsewhere in my code
    inputs = tokenizer.batch_encode_plus(texts,
                                         max_length=max_length,
                                         padding='longest',
                                         truncation=True,
                                         return_attention_mask=True,
                                         return_token_type_ids=False,
                                         return_tensors="pt")
    return inputs

tokenizer_output = my_encode(tokenizer, pandas_df['raw_text'].tolist())
# [end of EDIT]

I am getting an error when I call the model with tokenizer_output, which is a transformers.tokenization_utils_base.BatchEncoding:
result = model(**tokenizer_output)

The error is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed exec> in <module>

~/miniconda3/envs/x/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    485             output_attentions=output_attentions,
    486             output_hidden_states=output_hidden_states,
--> 487             return_dict=return_dict,
    488         )
    489 

~/miniconda3/envs/x/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict)
    305 
    306             layer_outputs = layer_module(
--> 307                 x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
    308             )
    309             hidden_state = layer_outputs[-1]

~/miniconda3/envs/x/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, x, attn_mask, head_mask, output_attentions)
    262 
    263         # Feed Forward Network
--> 264         ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
    265         ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
    266 

~/miniconda3/envs/x/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input)
    213 
    214     def forward(self, input):
--> 215         return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
    216 
    217     def ff_chunk(self, input):

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/modeling_utils.py in apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
   1815         return torch.cat(output_chunks, dim=chunk_dim)
   1816 
-> 1817     return forward_fn(*input_tensors)

~/miniconda3/envs/x/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in ff_chunk(self, input)
    217     def ff_chunk(self, input):
    218         x = self.lin1(input)
--> 219         x = self.activation(x)
    220         x = self.lin2(x)
    221         x = self.dropout(x)

~/miniconda3/envs/x/lib/python3.7/site-packages/torch/nn/functional.py in gelu(input)
   1457     if has_torch_function_unary(input):
   1458         return handle_torch_function(gelu, (input,), input)
-> 1459     return torch._C._nn.gelu(input)
   1460 
   1461 

RuntimeError: [enforce fail at CPUAllocator.cpp:67] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 2826240000 bytes. Error code 12 (Cannot allocate memory)

I believe calling the model in batches would solve my problem, but how can I use a batch_size here?
result = model(**tokenizer_output)

Is there another way to get the [CLS] (sequence representation)?

Thanks in advance!

You did not share how you are building your tokenizer_output, so it’s hard to help you.

You’re right. I just added the tokenizer_output code. Please let me know if you have any hint as to why I am getting this error. Thanks!

Your tokenizer_output contains the whole dataset, so it’s no wonder you run out of RAM. You need to split it into smaller chunks to pass through your model.
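For example, here is a minimal sketch of that idea (assuming the tokenizer_output from above already holds the full input_ids and attention_mask tensors; the chunk size is just a placeholder), slicing the encoded tensors and running the model chunk by chunk:

import torch

chunk_size = 256  # placeholder value, tune it to your RAM
cls_chunks = []
with torch.no_grad():  # no gradients are needed to extract embeddings
    n = tokenizer_output["input_ids"].size(0)
    for start in range(0, n, chunk_size):
        chunk = {k: v[start:start + chunk_size] for k, v in tokenizer_output.items()}
        # the first hidden state of each sequence is the [CLS] token
        cls_chunks.append(model(**chunk)[0][:, 0, :])
cls_embeddings = torch.cat(cls_chunks)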

One of my questions was how to set the batch_size param in DistilBertModel.
I implemented a naive piece of code to split my data into batches (please see below), but I am sure there is a better way using some ready-to-go API. Any ideas?

import torch

def batch_embeddings(tokenizer, texts, batch_size=256, max_length=MAX_LENGTH):
    all_cls_tokens = []

    with torch.no_grad():  # no gradients needed for feature extraction, which also saves memory
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            tokenizer_output = tokenizer.batch_encode_plus(batch,
                                                           max_length=max_length,
                                                           padding='longest',
                                                           truncation=True,
                                                           return_attention_mask=True,
                                                           return_token_type_ids=False,
                                                           return_tensors="pt")
            # model is the global DistilBertModel loaded earlier;
            # [:, 0, :] picks the hidden state of the [CLS] token for each sequence
            cls_token = model(**tokenizer_output)[0][:, 0, :]

            all_cls_tokens = all_cls_tokens + cls_token.tolist()
            # here I would save all_cls_tokens to disk from time to time and clear it,
            # to avoid the same RAM error as in the previous code (see the sketch below)

    return all_cls_tokens


all_cls_tokens = batch_embeddings(tokenizer, pd_df['text'].tolist(), 1000)
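For reference, the "save to disk from time to time" idea mentioned in the comment could look roughly like this; it is only a sketch, and the file prefix, max_length of 128 and batch size are placeholders, not values from my actual code:

import numpy as np
import torch

# Hypothetical sketch: write each batch of [CLS] vectors straight to disk
# instead of accumulating them all in one Python list.
def batch_embeddings_to_disk(tokenizer, model, texts, batch_size=256,
                             max_length=128, out_prefix="cls_part"):
    with torch.no_grad():
        for part, i in enumerate(range(0, len(texts), batch_size)):
            enc = tokenizer(texts[i:i + batch_size],
                            max_length=max_length,
                            padding='longest',
                            truncation=True,
                            return_tensors="pt")
            cls = model(**enc)[0][:, 0, :]          # shape (batch, hidden_dim)
            np.save(f"{out_prefix}_{part}.npy", cls.cpu().numpy())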

Also, I see several issues here since I am doing several conversions:
Initially my data is in a PySpark dataframe (I don’t show it here to simplify the question). I convert it to a pandas dataframe, then to a list. Inside batch_embeddings, I slice this list (variable batch) and pass it to tokenizer.batch_encode_plus; calling the model on the encoded batch returns a torch tensor (variable cls_token), which I convert to a list and append to the previous results. This is another reason I am sure there is a better solution.
What would be a better approach to deal with a lot of data using DistilBertModel?

Please help!
Thanks!

Why are you not using a PyTorch DataLoader to split your data in batches?
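As a rough sketch of what that could look like (the batch size of 64 is a placeholder, and pd_df, tokenizer, model and MAX_LENGTH are taken from the earlier posts; this is not a definitive implementation):

import torch
from torch.utils.data import DataLoader

texts = pd_df['text'].tolist()

def collate(batch_texts):
    # tokenize each mini-batch on the fly, so padding only goes up to the longest text in the batch
    return tokenizer(batch_texts,
                     max_length=MAX_LENGTH,
                     padding='longest',
                     truncation=True,
                     return_tensors="pt")

loader = DataLoader(texts, batch_size=64, shuffle=False, collate_fn=collate)

all_cls = []
with torch.no_grad():
    for enc in loader:
        all_cls.append(model(**enc)[0][:, 0, :])   # [CLS] vectors for this batch
cls_embeddings = torch.cat(all_cls)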

What else can he use to split the input into batches?