I’m doing research on gene segmentation, and I’m currently trying to implement a BERT network.
I have a 350 GB database (~300M proteins) of unlabelled proteins, where each protein is a sequence of letters of varying length and each letter corresponds to an amino acid. There are only about 20 distinct amino acids, so they are easy to tokenize and should be straightforward for a BERT model to work with.
(here is an example of 3 such proteins)
DGQPEIPAGRGEHPQGIPEDTSPNDIMSEVDLQMEFATRIAMESQLGDTLKSRLRISNAQTTDTGNYTCQPTTASSASVLVHVINGE
AGQLWLSIGLISGDDSLDTREGVDLVLKCRFTEHYDSTDFTFYWARWTCCPTLFENVAIGDVQLNSNYRLDFRPSRGIYDLQIKNTSYNRDNGRFECRIKAKGTGADVHQEFYNLTVLTAPHPPMVTPGNLAVATEEKPLELTCSSIGGSPDPMITWYREGSTVPLQSYALKGGSKNHYTNATLQIVPRRADDGAKYKCVVWNRAMPEGHMLETSVTLNVNYYPRVEVGPQNPLKVERDHVAKLDCRVDAKPMVSNVRWSRNGQYVSATPTHTIYRVNRHHAGKYTCSADNGLGKTGEKDIVLDVLYPPIVFIESKTHEAEEGETVLIRCNVTANPSPINVEWLKEGAPDFRYTGELLTLGSVRAEHAGNYICRSVNIMQPFSSKRVEGVGNSTVALLVRHRPGQAYITPNKPVVHVGNGVTLTCSANPPGWPVPQYRWFRDMDGDIGNTQKILAQGPQYSIPKAHLGSEGKYHCHAVNELGIGKIATIILEVHQPPQFLAKLQQHMTRRVGDVDYAVTCSAKGKPTPQIRWIKDGTEILPTRKMFDIRTTPTDAGGGVVAVQSILRFRGKARPNGNQLLPNDRGLYTCLYENDVNSANSSMHLRIEHEPIVIHQYNKVAYDLRESAEVVCRVQAYPKPEFQWQYGNNPSPLTMSSDGHYEISTRMENNDVYTSILRIAHLQHSDYGEYICRAVNPLDSIRAPIRLQPKGSPEKPTNLKILEVGHNYAVLNWTPGFNGGFMSTKYLVSYRRVATPREQTLSDCSGNGYIPSYQISSSSSNSNHEWIEFNCFKENPCKLAPLDQHQSYMFKVYALNSKGTSGYSNEILATTKVSKIPPPLHVSYDPNSHVLGINVAATCLSLIAVVESLVTRDATVPMWEIVETLTLLPSGSETTFKEAIINHVSRPAHYTTATTSGRSLGVGGGSHLGEDRTMALAETAGPGPVVRVKLCLRSNHEHCGAYANAEIGKSYMPHKSSMTTSALVAIIIASLSFVVFLGLLYAFCHCRRKHAAKKESSSVGGGVGGGNANATANPGSTGAKEYDLDLDASRRPSLSQDPQQSQQQPPPPPPYYPTGTLDSKDIGNGNGGMELTLTALHDPDEQLNMQQQQHHSNHGQYQQPKAILGIYGGVAGSGGNNSGGQHPHSNGYGYHVTSAIGVDSDSYQVLPSVANSAAGSHGHGSGHGHGLGAGEXPLEATPPTCNISGGSSSNSGINPMQQQHSARANLTNQPTIATASSTNNYNNHLNNTNIAHTTNNTNNCTTLKRGHLGNRERERERCQVTAATAATTTALATTITTTSRNAKAATTTTTLAITGSSSNSNENNYSNARDLSQEALWRLQMATAQSQQIYVERPPSAFSGLVDYSGYSPHIPTVTSSLSQQSFSPTQQLAPHEMLQAAQRYGTLRKSGKQPPLPPQRKDMQQQAKPPQQMADIIQDLAN
MKQINAASALCGQLKQHENRAGPSNLGNVISQILLCKQFTPDFNEEELCSITKDSQDIAVLLAEMQEYMPQHEAYLERNAALDTTGPWQAKRRQNYICKNMSLLCCVS
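Since the alphabet is so small, the tokenizer can be a plain dictionary. Here is a minimal sketch (all names here are illustrative, not from the post); note that real protein databases occasionally contain non-standard letters such as X, so an [UNK] token is worth having:

```python
# The 20 standard amino acids plus special tokens. Ids 0-2 are reserved
# for [PAD], [MASK], and [UNK] (used for rare letters such as X).
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
SPECIALS = ["[PAD]", "[MASK]", "[UNK]"]

vocab = {tok: i for i, tok in enumerate(SPECIALS + list(AMINO_ACIDS))}

def encode(seq):
    """Map a protein string to a list of integer token ids."""
    return [vocab.get(aa, vocab["[UNK]"]) for aa in seq]
```

With this scheme the vocabulary has 23 entries, and any unexpected letter falls back to [UNK] rather than crashing the pipeline.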
Furthermore, I have a few small databases of 800-25,000 labelled proteins, meaning that for each amino acid in each protein there is one of three labels (0 = positive, 1 = negative, 2 = unknown).
So in this case my data would look like this (two small proteins shown here):
DGQPEIPAGRGEHPQGI, 22222111111222222
YLERNAALDTTGPWQA, 00002222222222222
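A labelled line like the ones above can be split into parallel per-residue sequences with a few lines of code (a sketch; the function name is mine). The main thing to enforce is the one-label-per-amino-acid alignment; the "unknown" label can then be excluded from the loss later:

```python
def parse_example(protein, labels):
    """Split one labelled line into parallel character and label-id lists."""
    assert len(protein) == len(labels), "one label per amino acid"
    label_ids = [int(c) for c in labels]   # "0"/"1"/"2" -> 0/1/2
    return list(protein), label_ids
```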
So far I have built a standard BERT model with a linear layer on top (see code below), which I use with a cross-entropy loss to predict the 15% of amino acids that are masked in each protein. It should be noted that I do not start each input sequence with [CLS] or end it with [SEP]; since each sequence is only one protein, I figured I could do without them (though maybe this is wrong and they should still be included?).
class BERTseq(nn.Module):
    def __init__(self, bert: BERT, vocab_size):
        super().__init__()
        self.bert = bert
        self.linear = nn.Linear(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.linear(x)
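For reference, the 15% masking step that feeds this model could look roughly like the following sketch (MASK_ID and PAD_ID are my assumptions about the vocabulary, not from the post). Setting unmasked targets to -100 lets nn.CrossEntropyLoss score only the masked positions:

```python
import torch

MASK_ID, PAD_ID = 1, 0  # assumed special-token ids

def mask_tokens(token_ids, mask_prob=0.15):
    """Randomly mask ~15% of non-padding tokens for BERT-style pretraining."""
    inputs = token_ids.clone()
    targets = token_ids.clone()
    maskable = inputs != PAD_ID
    chosen = (torch.rand(inputs.shape) < mask_prob) & maskable
    targets[~chosen] = -100          # ignored by the loss
    inputs[chosen] = MASK_ID         # model must reconstruct these
    return inputs, targets

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
# logits: (batch, seq_len, vocab_size) from BERTseq, then
# loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
```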
This part seems to work rather well: after training for a few hours I get 96% accuracy on the test data when predicting the masked amino acids.
Now I want to use this pretrained network for the segmentation task, but I must admit I'm uncertain how exactly to do this. My initial idea was to remove the last linear layer and insert a new one of size nn.Linear(self.bert.hidden, 2), since I want 2 segmentation classes out.
However, in this case I don't know whether I should still use the masking approach from before, or whether I should just train on the full label sequences directly. Also, should I keep the old linear layer and add another layer on top of it, or replace it with the new linear layer?
So far I have tried replacing it and feeding the segmentation labels directly into BERT as input, but that doesn't seem to work.
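For what it's worth, one common fine-tuning recipe looks like the sketch below (my assumptions throughout; DummyEncoder is only a stand-in for the pretrained BERT). The key points: keep the pretrained encoder, replace the masked-LM head with a fresh linear head, drop the masking entirely, and use the label sequence only as the target of the cross-entropy loss, never as input, with ignore_index skipping the "unknown" positions:

```python
import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    """Stand-in for the pretrained BERT from the post (illustrative only)."""
    def __init__(self, vocab_size=23, hidden=32):
        super().__init__()
        self.hidden = hidden
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, x, segment_label):
        return self.embed(x)                 # (batch, seq_len, hidden)

class BERTTagger(nn.Module):
    """Pretrained encoder plus a fresh per-token classification head."""
    def __init__(self, bert, num_classes=2):
        super().__init__()
        self.bert = bert                                 # reuse pretrained weights
        self.head = nn.Linear(bert.hidden, num_classes)  # replaces the MLM head

    def forward(self, x, segment_label):
        h = self.bert(x, segment_label)                  # (batch, seq_len, hidden)
        return self.head(h)                              # (batch, seq_len, num_classes)

# Fine-tuning step: no masking; labels are targets, not inputs.
model = BERTTagger(DummyEncoder(), num_classes=2)
loss_fn = nn.CrossEntropyLoss(ignore_index=2)  # skip "unknown" positions

tokens = torch.randint(0, 23, (1, 6))
segments = torch.zeros_like(tokens)
labels = torch.tensor([[0, 0, 1, 2, 2, 1]])    # per-residue labels
logits = model(tokens, segments)
loss = loss_fn(logits.view(-1, 2), labels.view(-1))
```

Whether to freeze the encoder or fine-tune it end-to-end (usually with a small learning rate) is a separate choice worth experimenting with on the small labelled sets.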