Best way to use a BERT transformer on each sentence in a document?

I’m not quite sure how to put this, but I’ll try my best. The actual model is rather complex, so I’ll simplify it a bit. My goal is to predict the sentiment of each sentence in a document from the corpus. My naïve approach was to treat each sentence as independent of the others, i.e. the order doesn’t matter, so each sentence can be treated as a separate input to the model. My model would look something like this:

import torch
from torch.nn import CrossEntropyLoss
from transformers.models.bert.modeling_bert import BertPooler

class my_model(Bert_model):
    def __init__(self, bert_base_model, args):
        # some stuff
        self.bert_for_global_context = bert_base_model
        self.dense = torch.nn.Linear(self.embed_dim, 3)
        self.pooler = BertPooler(config)

    def forward(self, input, labels):
        # more stuff
        global_context_out = self.bert_for_global_context(input)['last_hidden_state']
        pooled_output = self.pooler(global_context_out)  # [batch size, embed dim]
        logits = self.dense(pooled_output)               # [batch size, 3]
        loss_fun = CrossEntropyLoss()
        loss = loss_fun(logits, labels)
        return loss

model = my_model(bert_base_model, args)

Training loop:

for step, batch in enumerate(train_dataloader):
    input, labels = batch
    loss = model(input, labels)
    # more stuff

Note that the input is of shape [batch size, sentence length]. I then started thinking about how to model the dependency between sentences, and my idea was to model the sequence of sentiments with a conditional random field, using the pytorch-crf package (pytorch-crf 0.7.2 documentation). To do this, I reshaped my dataset so that each input is of size [batch size, document length, sentence length]. My model now looks like this:

from torchcrf import CRF

class my_model(Bert_model):
    def __init__(self, bert_base_model, args):
        # some stuff
        self.bert_for_global_context = bert_base_model
        self.dense = torch.nn.Linear(self.embed_dim, 3)
        self.pooler = BertPooler(config)
        self.crf = CRF(3, batch_first=True)

    def forward(self, input, labels):
        # more stuff
        # input arrives flattened as [batch size * document length, sentence length]
        global_context_out = self.bert_for_global_context(input)['last_hidden_state']
        pooled_output = self.pooler(global_context_out)
        logits = self.dense(pooled_output)                     # [batch size * document length, 3]
        logits = logits.view(batch_size, document_length, 3)   # per-document emissions for the CRF
        # pytorch-crf's forward returns the log likelihood, so negate it to get a loss
        loss = -self.crf(logits, labels)
        return loss

model = my_model(bert_base_model, args)

Training loop:

for step, batch in enumerate(train_dataloader):
    input, labels = batch   # input: [batch size, document length, sentence length]
    input = input.view(batch_size * document_length, sentence_length)
    loss = model(input, labels)
    # more stuff
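In case it helps, here is a small self-contained sketch of the shape bookkeeping I have in mind (the sizes and tensors are just made-up placeholders):

import torch

batch_size, document_length, sentence_length = 8, 12, 64

# one batch of token ids: each document holds document_length sentences
inputs = torch.randint(0, 30000, (batch_size, document_length, sentence_length))
labels = torch.randint(0, 3, (batch_size, document_length))  # one sentiment label per sentence

# flatten the documents so BERT sees one sentence per row
flat_inputs = inputs.view(batch_size * document_length, sentence_length)

# after BERT + pooler + dense the logits come back as [batch size * document length, 3];
# here I just fake them with random numbers
logits = torch.randn(batch_size * document_length, 3)

# reshape into per-document sequences of emissions for the CRF
emissions = logits.view(batch_size, document_length, 3)
print(flat_inputs.shape, emissions.shape, labels.shape)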

I hope this makes sense. Anyway, while my new model does work, it takes about 10 times as long to train. The two major changes are the use of view and the CRF. Does it make sense for changing the shape of a tensor with view to cause training to slow down so much? If so, is there a better way to do this task?
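For what it’s worth, my understanding is that view only changes how the same underlying storage is indexed and doesn’t copy any data, which is why the slowdown surprised me. A quick sanity check (sizes made up):

import torch

x = torch.randn(8 * 12, 64)
y = x.view(8, 12, 64)

# both tensors share the same memory, so the reshape itself should be essentially free
print(x.data_ptr() == y.data_ptr())  # True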

Note: I’m not sure how good pytorch-crf is, so maybe that’s causing the slowdown.
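For reference, this is roughly how I understand pytorch-crf is meant to be used (based on its docs; all sizes here are made up):

import torch
from torchcrf import CRF

num_tags, batch_size, document_length = 3, 8, 12

crf = CRF(num_tags, batch_first=True)
emissions = torch.randn(batch_size, document_length, num_tags)    # per-sentence logits
tags = torch.randint(0, num_tags, (batch_size, document_length))  # gold sentiment per sentence

# forward() returns the log likelihood of the tag sequences, so the loss is its negative
loss = -crf(emissions, tags)

# decode() runs Viterbi and returns the most likely tag sequence for each document
best_paths = crf.decode(emissions)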