We found that BertModel and T5EncoderModel with our Rostlab/prot_bert_bfd and Rostlab/prot_t5_xl_uniref50 models produce embeddings that are sometimes very different when using batching vs. when processing each input independently. The difference becomes significantly larger when using fp16 instead of fp32. I’m wondering: is this just expected numerical instability, some problem with our setup, or a bug in transformers?
Through the script below I get the following output, where max is the maximum absolute difference across the embedding dimensions:
Rostlab/prot_bert_bfd half=True
max: 0.0888671875 mean: 3.6954879760742188e-06
max: 0.004638671875 mean: -5.960464477539062e-07
max: 0.02490234375 mean: 1.3113021850585938e-06
max: 0.0061798095703125 mean: 4.76837158203125e-07
max: 0.0380859375 mean: -1.2516975402832031e-06
Rostlab/prot_bert_bfd half=False
max: 7.155537605285645e-05 mean: -6.981906164327256e-10
max: 3.3676624298095703e-06 mean: -5.9122234885578e-10
max: 1.5795230865478516e-05 mean: 6.151099629647661e-10
max: 7.0035457611083984e-06 mean: -9.53980561213541e-10
max: 5.263090133666992e-05 mean: 3.6568654770974263e-09
Rostlab/prot_t5_xl_uniref50 half=True
max: 0.003173828125 mean: -5.364418029785156e-07
max: 0.001953125 mean: 2.384185791015625e-07
max: 0.001708984375 mean: 2.980232238769531e-07
max: 0.00347900390625 mean: 5.0067901611328125e-06
max: 0.0029296875 mean: 1.1920928955078125e-07
Rostlab/prot_t5_xl_uniref50 half=False
max: 1.6689300537109375e-06 mean: 1.7499914017893303e-10
max: 1.0281801223754883e-06 mean: 1.4463226449823452e-10
max: 6.556510925292969e-07 mean: -9.230356930178818e-12
max: 2.6971101760864258e-06 mean: 1.4639487400103235e-09
max: 9.238719940185547e-07 mean: -9.745625834112204e-11
Code:
import re
from itertools import zip_longest
from typing import List, Union
import numpy
import torch
from numpy import ndarray
from transformers import T5EncoderModel, T5Tokenizer, BertModel, BertTokenizer
# Test inputs: protein sequences of varying length. Each character is one
# amino acid, which is treated as one "word" by the tokenizer (the sequences
# are space-joined before tokenization). The differing lengths force padding
# when the sequences are embedded as a single batch.
sequences = [
"MFLFFCAATILCLWVNSGGAVVVSNETLVFCEPVSYPYSLQVLRSFSQRVNLRTKRAVIIDAWSFAYQISTNSLNVNGWYVNFTSPLGWSYPNGKPFGIVLGSDAMMRASQSIFTYDVISYVGQRPNLDCQINDLVNGGLKNWYSTVRVDNCGNYPCHGGGKPGCSIGQPYMANGVCTRVLSTTQSPGIQYEIYSGQDYAVYQITPYTQYTVTMPSGTSGYCQQTPLYVECGSWTPYRVHTYGCDKVTQSCKYTISSDWVVAFKSKITAVTLPSDLKVPVVQKVTKRLGVTSPDYFWLIKQAYQYLSQATISPNYALFSALSNSLYQQSLVLTDLCYGSPFFMARECYNNALYLPDAVFTTLFSILFSWDYQVNYPVNNVLQSNETFLQLPTTGYLGQTVSQGRMLNLFKDAIVFLDFYDTKFYRTNDGPGGDIFAVVVKQAPVIAYSAFRIEQQTGDYLAVKCNGVTQATLAPHSSRVVLLARHMSMWSIAAANSTTIYCPIYTLTQFGSLDISTSWYFHTLAQPSGPIQQVSMPLLSTAAAGVYMYPMVEHWVTLLTQTQDVYQPSMFNMGVNKSVTLTTQLQAYAQVYTAWFLSILYTRLPESRRLTLGAQLTPFIQALLSFRQADIDATDVDTVARYNVLSLMWGRKYAAVSYNQLPEWSYPLFKGGVGDSMWFRKEISCTTQNPSTSSHFPFIAGYLDFLDYKYIPKYKDVACPTTMVTPTLLQVYETPQLFVIIVQCVSTTYSWYPGLRNPHTIYRSYKLGTICILVPYSSPTSVYSSFGFFFQSALTIPIVQTTDDILPGCVGFVQDSVFTPCHPSGCPVRNSYDNYIICPGSSASNYTLRNYHRTTIPVTNVPIDEVPLQLEIPTVSLTSYELKQSESVLLQDIEGGIVVDHNTGSIWYPDGQAYDVSFYVSVIIRYAPPKLELPSTLANFTSCLDYICFGNQQCRGEAQTFCTSMDYFEQVFNKSLTSLIIALQDLHYVLKLVLPETTLELTEDTRRRRRAVDEFSDTISLLSESFERFMSPASQAYMANMMWWDEAFDGISLPQRTGSILSRTPSLSSTSSWRSYSSRTPLISNVKTPKTTFNVKLSMPKLPKASTLSTIGSVLSSGLSIASLGLSIFSIIEDRRVTELTQQQIMALENQITILTDYTEKNFKEIQSFLNTLGQQVQDFSQQVTLSLQQLFNGLEQITQQLDKSIYYVMAVQQYATYMSSFVNQLNELSQAVYKTQDMYITCIHSLQSGVLSPNCITPAQMFHLYQVAKNLSGECQPIFSEREISRFYSLPLVTDAMVHNDTYWFSWSIPITCSNILGSVYKVQPGYIVNPHHPTSLQYDVPTHVVTSNAGALIFDEHYCDRYNQVYLCTKSAFDLAESSYLTMLYSNQTDNSSLTFHPEPRPVPCVYLSASALYCYYSDECHQCVIAVGNCTNRTVTYENYTYSIMDPQCRGFDQVTISSPIAIGADFTALPSRPPLPLHLSYVNVTFNVTLPNGVNWTDLVLDYSFKDKVYEISKNITQLHEQILQVSNWASGWFQRLRDFLYGLIPAWITWLTLGFSLFSILISGVNIILFFEMNGKVKKS",
"MKKLFVVLVVMPLIYGDNFPCSKLTNRTIGNHWNLIETFLLNYSSRLPPNSDVVLGDYFPTVQPWFNCIRNNSNDLYVTLENLKALYWDYAKETITWNHKQRLNVVVNGYPYSITVTTTRNFNSAEGAIICICKGSPPTTTTESSLTCNWGSECRLNHKFPICPSNSESNCGNMLYGLQWFADE",
"MAHLCTQQARPMEWNTFFLVILIIIIKSTTPQITQRPPVENISTYHADWDTPLYTHPSNCRDDSFVPIRPAQLRCPHEFEDINKGLVSVPTKIIHLPLSVTSVSAVASGHYLHRVTYRVTCSTSFFGGQTIEKTILEAKLSRQEATDEASKDHEYPFFPEPSCIWMKNNVHKDITHYYKTPKTVSVDLYSRKFLNPDFIEGVCTTSPCQTHWQGVYWVGATPKAHCPTSETLEGHLFTRTHDHRVVKAIVAGHHPWGLTMACTVTFCGAEWIKTDLGDLIQVTGPGGTGKLTPKKCVNADVQMRGATDDFSYLNHLITNMAQRTECLDAHSDITASGKISSFLLSKFRPSHPGPGKAHYLLNGQIMRGDCDYEAVVSINYNSAQYKTVNNTWKSWKRVDNNTDGYDGMIFGDKLIIPDIEKYQSVYDSGMLVQRNLVEVPHLSIVFVSNTSDLSTNHIHTNLIPSDWSFHWSIWPSLSGMGVVGGAFLLLVLCCCCKASPPIPNYGIPMQQFSRSQTV",
"MIVLVTCLLLLCSYHTVLSTTNNECIQVNVTQLAGNENLIRDFLFSNFKEEGSVVVGGYYPTEVWYNCSRTAWTTAFQYFNNIHAFYFVMEAMENSTGNARGKPLLFHVHGEPVSVIIYISAYRDDVQQRPLLKHGLVCITKNRHINYEQFTSNQWNSTCTGADRKIPFSVIPTDNGTKIYGLEWNDDFVTAYISGRSYHLNINTNWFNNVTLLYSRSSTATWEYSAAYAYQGVSNFTYYKLNNTNGLKTYELCEDYEHCTGYATNVFAPTSGGYIPDGFSFNNWFLLTNSSTFVSGRFV",
"MLVKSLFIVTILFALCSANLYDNHAYVYYYQSAFRPPNGWHLHGGAYKVVNVSSERNNAGGADTCTAGAIYWSKNFSASSVAMTAPLSGMEWSTSQFCTAHCNFTDITVFVTHCFKAGSGNCPLTGLIPKDHIRISAMKKSGSGPSDLFYNLTVSVTKYPKFMSLQCVNNLTSVYLNGYLVFTSNETKDVMAAGVHFKAGGPITYKVMREVKAMAYFINGTAQDIILCDGSPRGLLACQYNTGNFSDGFYPFTNSSLVKKKFFVYRETSISTTLVLYNLTFSNVSNASPNKGGVYTIDLHQTQTAQDGYYNFDFSFLSSSVNVESNFMYGSYHPKCSFRPENINNGLWFNSISISLAYGPLQGGCKQSVFNNRATCCYAYSYNGPSLCKGVYSGELQQTFECGLLVYVTKSDGSRIQTASKPPVITQHNYNNITLNTCVDYNIYGSTGQGFITNVTDHAANYNYLASGGLAILDASGAIDIFVVQGEYGLNYYKVNPCEDVNQQFVVSGGKLVGILTSRNETDSQLLENQFYIKLTNGTRRSRR",
]
def embed_batch(
    batch: "List[str]",
    model: "Union[BertModel, T5EncoderModel]",
    tokenizer: "Union[BertTokenizer, T5Tokenizer]",
) -> "List[ndarray]":
    """Embed a batch of protein sequences and return one array per sequence.

    Each sequence is cleaned (rare amino acids U/Z/O/B mapped to X), split
    into space-separated single-residue "words", tokenized with padding to
    the longest sequence in the batch, and run through the encoder. The
    padded per-token embeddings are then trimmed back to each sequence's
    original (pre-tokenization) length.

    NOTE(review): with add_special_tokens=True the BERT tokenizer prepends a
    [CLS] token, so row i of the returned array is offset by one relative to
    residue i. The offset is identical for batched and single-sequence calls,
    so it does not affect the comparison this script performs — confirm if
    the embeddings are consumed per-residue elsewhere.
    """
    seq_lens = [len(seq) for seq in batch]
    # Remove rare amino acids — the models map them to the generic X token.
    batch = [re.sub(r"[UZOB]", "X", sequence) for sequence in batch]
    # Every amino acid is a "word" for the tokenizer.
    batch = [" ".join(list(seq)) for seq in batch]
    ids = tokenizer.batch_encode_plus(batch, add_special_tokens=True, padding="longest")
    tokenized_sequences = torch.tensor(ids["input_ids"]).to(model.device)
    attention_mask = torch.tensor(ids["attention_mask"]).to(model.device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        embeddings = model(input_ids=tokenized_sequences, attention_mask=attention_mask)
    embeddings = embeddings[0].cpu().numpy()
    # Plain zip, not zip_longest: both lists always have exactly len(batch)
    # entries, and zip_longest would silently pad a mismatch with None
    # instead of surfacing it.
    return [embedding[:seq_len] for embedding, seq_len in zip(embeddings, seq_lens)]
def main():
    """Compare batched vs. single-sequence embeddings for two protein models.

    Runs each model in fp16 and fp32, embeds the test sequences both one at a
    time and as one batch, then reports the max/mean elementwise difference
    per sequence. Requires a CUDA device.
    """
    device = torch.device("cuda")
    configurations = [
        (BertModel, BertTokenizer, "Rostlab/prot_bert_bfd"),
        (T5EncoderModel, T5Tokenizer, "Rostlab/prot_t5_xl_uniref50"),
    ]
    results = []
    for model_class, tokenizer_class, model_name in configurations:
        for half in [True, False]:
            print(f"{model_name} half={half}")
            model = model_class.from_pretrained(model_name)
            if half:
                # This here makes the difference
                model = model.half()
            model = model.to(device).eval()
            tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=False)
            # One forward pass per sequence ...
            per_sequence = [
                embed_batch([sequence], model, tokenizer)[0] for sequence in sequences
            ]
            # ... versus all sequences padded into a single batch.
            batched = embed_batch(sequences, model, tokenizer)
            results.append((model_name, half, per_sequence, batched))
    # Report everything after all configurations have been computed.
    for model_name, half, per_sequence, batched in results:
        print(f"{model_name} half={half}")
        for single, from_batch in zip(per_sequence, batched):
            diff = single - from_batch
            print(
                f"max: {numpy.max(numpy.abs(diff)):<25} mean: {numpy.mean(diff):<25}"
            )
if __name__ == "__main__":
    main()
(It also complains that we’re not using the decoder weights, but this is expected.)
T5Model in fp16 still yield nan with more complex examples · Issue #4586 · huggingface/transformers · GitHub talks about nan values in T5 fp16, but otherwise I couldn’t find anything similar.
Environment info

- `transformers` version: 4.5.1
- Platform: Linux-4.15.0-126-generic-x86_64-with-glibc2.10 (Ubuntu 18.04)
- Python version: 3.8.0
- PyTorch version (GPU?): 1.7.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: Nvidia Quadro RTX 8000
- Using distributed or parallel set-up in script?: no