Large max differences between single-input processing and batching with BERT and T5

We found that BertModel and T5EncoderModel with our Rostlab/prot_bert_bfd and Rostlab/prot_t5_xl_uniref50 models produce, for some inputs, noticeably different embeddings when using batching vs. when processing each input independently. The difference becomes significantly larger when using fp16 instead of fp32. Is this just expected numerical instability, a problem with our setup, or a bug in transformers?
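
My working assumption (not verified against the transformers or cuBLAS internals) is that batching changes the matmul shapes, and with them possibly the kernels and reduction order, so bitwise-identical results are not guaranteed even though the padded positions are masked out; fp16 would then just amplify those rounding differences. A minimal sketch of that effect with a plain Linear layer, independent of our models (the exact numbers depend on the GPU and library versions and may even be zero for some shapes):

import copy

import torch

# Compare the same input processed alone vs. inside a larger batch through a
# single Linear layer, in fp16 and fp32. Same setup as the script below: requires a GPU.
torch.manual_seed(0)
device = torch.device("cuda")
base = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(1, 1024, device=device)

for dtype in [torch.float16, torch.float32]:
    layer = copy.deepcopy(base).to(dtype)
    with torch.no_grad():
        single = layer(x.to(dtype))                    # batch size 1
        batched = layer(x.to(dtype).repeat(8, 1))[:1]  # same row inside a batch of 8
    print(dtype, (single - batched).abs().max().item())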

With the script below I get the following output, where max is the maximum absolute difference between the single-sequence and the batched embedding of a sequence (across all positions and dimensions) and mean is the mean difference:

Rostlab/prot_bert_bfd half=True
max: 0.0888671875              mean: 3.6954879760742188e-06
max: 0.004638671875            mean: -5.960464477539062e-07
max: 0.02490234375             mean: 1.3113021850585938e-06
max: 0.0061798095703125        mean: 4.76837158203125e-07
max: 0.0380859375              mean: -1.2516975402832031e-06
Rostlab/prot_bert_bfd half=False
max: 7.155537605285645e-05     mean: -6.981906164327256e-10
max: 3.3676624298095703e-06    mean: -5.9122234885578e-10
max: 1.5795230865478516e-05    mean: 6.151099629647661e-10
max: 7.0035457611083984e-06    mean: -9.53980561213541e-10
max: 5.263090133666992e-05     mean: 3.6568654770974263e-09
Rostlab/prot_t5_xl_uniref50 half=True
max: 0.003173828125            mean: -5.364418029785156e-07
max: 0.001953125               mean: 2.384185791015625e-07
max: 0.001708984375            mean: 2.980232238769531e-07
max: 0.00347900390625          mean: 5.0067901611328125e-06
max: 0.0029296875              mean: 1.1920928955078125e-07
Rostlab/prot_t5_xl_uniref50 half=False
max: 1.6689300537109375e-06    mean: 1.7499914017893303e-10
max: 1.0281801223754883e-06    mean: 1.4463226449823452e-10
max: 6.556510925292969e-07     mean: -9.230356930178818e-12
max: 2.6971101760864258e-06    mean: 1.4639487400103235e-09
max: 9.238719940185547e-07     mean: -9.745625834112204e-11
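
To put the fp16 values into perspective it may help to look at relative instead of absolute differences, since fp16 has a machine epsilon of roughly 1e-3 and the significance of a ~1e-2 absolute difference depends on the magnitude of the embedding values. A small helper for that, to be applied to the trimmed embeddings returned by embed_batch below (a and b being the single-sequence and the batched embedding of the same sequence):

import numpy

def max_relative_difference(a: numpy.ndarray, b: numpy.ndarray) -> float:
    """Largest elementwise difference, scaled by the magnitude of the values."""
    a = a.astype(numpy.float64)
    b = b.astype(numpy.float64)
    scale = numpy.maximum(numpy.maximum(numpy.abs(a), numpy.abs(b)), 1e-12)
    return float(numpy.max(numpy.abs(a - b) / scale))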

Code:

import re
from typing import List, Union

import numpy
import torch
from numpy import ndarray
from transformers import T5EncoderModel, T5Tokenizer, BertModel, BertTokenizer

# These are protein sequences; each character is one amino acid, which corresponds to one word
sequences = [
    "MFLFFCAATILCLWVNSGGAVVVSNETLVFCEPVSYPYSLQVLRSFSQRVNLRTKRAVIIDAWSFAYQISTNSLNVNGWYVNFTSPLGWSYPNGKPFGIVLGSDAMMRASQSIFTYDVISYVGQRPNLDCQINDLVNGGLKNWYSTVRVDNCGNYPCHGGGKPGCSIGQPYMANGVCTRVLSTTQSPGIQYEIYSGQDYAVYQITPYTQYTVTMPSGTSGYCQQTPLYVECGSWTPYRVHTYGCDKVTQSCKYTISSDWVVAFKSKITAVTLPSDLKVPVVQKVTKRLGVTSPDYFWLIKQAYQYLSQATISPNYALFSALSNSLYQQSLVLTDLCYGSPFFMARECYNNALYLPDAVFTTLFSILFSWDYQVNYPVNNVLQSNETFLQLPTTGYLGQTVSQGRMLNLFKDAIVFLDFYDTKFYRTNDGPGGDIFAVVVKQAPVIAYSAFRIEQQTGDYLAVKCNGVTQATLAPHSSRVVLLARHMSMWSIAAANSTTIYCPIYTLTQFGSLDISTSWYFHTLAQPSGPIQQVSMPLLSTAAAGVYMYPMVEHWVTLLTQTQDVYQPSMFNMGVNKSVTLTTQLQAYAQVYTAWFLSILYTRLPESRRLTLGAQLTPFIQALLSFRQADIDATDVDTVARYNVLSLMWGRKYAAVSYNQLPEWSYPLFKGGVGDSMWFRKEISCTTQNPSTSSHFPFIAGYLDFLDYKYIPKYKDVACPTTMVTPTLLQVYETPQLFVIIVQCVSTTYSWYPGLRNPHTIYRSYKLGTICILVPYSSPTSVYSSFGFFFQSALTIPIVQTTDDILPGCVGFVQDSVFTPCHPSGCPVRNSYDNYIICPGSSASNYTLRNYHRTTIPVTNVPIDEVPLQLEIPTVSLTSYELKQSESVLLQDIEGGIVVDHNTGSIWYPDGQAYDVSFYVSVIIRYAPPKLELPSTLANFTSCLDYICFGNQQCRGEAQTFCTSMDYFEQVFNKSLTSLIIALQDLHYVLKLVLPETTLELTEDTRRRRRAVDEFSDTISLLSESFERFMSPASQAYMANMMWWDEAFDGISLPQRTGSILSRTPSLSSTSSWRSYSSRTPLISNVKTPKTTFNVKLSMPKLPKASTLSTIGSVLSSGLSIASLGLSIFSIIEDRRVTELTQQQIMALENQITILTDYTEKNFKEIQSFLNTLGQQVQDFSQQVTLSLQQLFNGLEQITQQLDKSIYYVMAVQQYATYMSSFVNQLNELSQAVYKTQDMYITCIHSLQSGVLSPNCITPAQMFHLYQVAKNLSGECQPIFSEREISRFYSLPLVTDAMVHNDTYWFSWSIPITCSNILGSVYKVQPGYIVNPHHPTSLQYDVPTHVVTSNAGALIFDEHYCDRYNQVYLCTKSAFDLAESSYLTMLYSNQTDNSSLTFHPEPRPVPCVYLSASALYCYYSDECHQCVIAVGNCTNRTVTYENYTYSIMDPQCRGFDQVTISSPIAIGADFTALPSRPPLPLHLSYVNVTFNVTLPNGVNWTDLVLDYSFKDKVYEISKNITQLHEQILQVSNWASGWFQRLRDFLYGLIPAWITWLTLGFSLFSILISGVNIILFFEMNGKVKKS",
    "MKKLFVVLVVMPLIYGDNFPCSKLTNRTIGNHWNLIETFLLNYSSRLPPNSDVVLGDYFPTVQPWFNCIRNNSNDLYVTLENLKALYWDYAKETITWNHKQRLNVVVNGYPYSITVTTTRNFNSAEGAIICICKGSPPTTTTESSLTCNWGSECRLNHKFPICPSNSESNCGNMLYGLQWFADE",
    "MAHLCTQQARPMEWNTFFLVILIIIIKSTTPQITQRPPVENISTYHADWDTPLYTHPSNCRDDSFVPIRPAQLRCPHEFEDINKGLVSVPTKIIHLPLSVTSVSAVASGHYLHRVTYRVTCSTSFFGGQTIEKTILEAKLSRQEATDEASKDHEYPFFPEPSCIWMKNNVHKDITHYYKTPKTVSVDLYSRKFLNPDFIEGVCTTSPCQTHWQGVYWVGATPKAHCPTSETLEGHLFTRTHDHRVVKAIVAGHHPWGLTMACTVTFCGAEWIKTDLGDLIQVTGPGGTGKLTPKKCVNADVQMRGATDDFSYLNHLITNMAQRTECLDAHSDITASGKISSFLLSKFRPSHPGPGKAHYLLNGQIMRGDCDYEAVVSINYNSAQYKTVNNTWKSWKRVDNNTDGYDGMIFGDKLIIPDIEKYQSVYDSGMLVQRNLVEVPHLSIVFVSNTSDLSTNHIHTNLIPSDWSFHWSIWPSLSGMGVVGGAFLLLVLCCCCKASPPIPNYGIPMQQFSRSQTV",
    "MIVLVTCLLLLCSYHTVLSTTNNECIQVNVTQLAGNENLIRDFLFSNFKEEGSVVVGGYYPTEVWYNCSRTAWTTAFQYFNNIHAFYFVMEAMENSTGNARGKPLLFHVHGEPVSVIIYISAYRDDVQQRPLLKHGLVCITKNRHINYEQFTSNQWNSTCTGADRKIPFSVIPTDNGTKIYGLEWNDDFVTAYISGRSYHLNINTNWFNNVTLLYSRSSTATWEYSAAYAYQGVSNFTYYKLNNTNGLKTYELCEDYEHCTGYATNVFAPTSGGYIPDGFSFNNWFLLTNSSTFVSGRFV",
    "MLVKSLFIVTILFALCSANLYDNHAYVYYYQSAFRPPNGWHLHGGAYKVVNVSSERNNAGGADTCTAGAIYWSKNFSASSVAMTAPLSGMEWSTSQFCTAHCNFTDITVFVTHCFKAGSGNCPLTGLIPKDHIRISAMKKSGSGPSDLFYNLTVSVTKYPKFMSLQCVNNLTSVYLNGYLVFTSNETKDVMAAGVHFKAGGPITYKVMREVKAMAYFINGTAQDIILCDGSPRGLLACQYNTGNFSDGFYPFTNSSLVKKKFFVYRETSISTTLVLYNLTFSNVSNASPNKGGVYTIDLHQTQTAQDGYYNFDFSFLSSSVNVESNFMYGSYHPKCSFRPENINNGLWFNSISISLAYGPLQGGCKQSVFNNRATCCYAYSYNGPSLCKGVYSGELQQTFECGLLVYVTKSDGSRIQTASKPPVITQHNYNNITLNTCVDYNIYGSTGQGFITNVTDHAANYNYLASGGLAILDASGAIDIFVVQGEYGLNYYKVNPCEDVNQQFVVSGGKLVGILTSRNETDSQLLENQFYIKLTNGTRRSRR",
]


def embed_batch(
    batch: List[str],
    model: Union[BertModel, T5EncoderModel],
    tokenizer: Union[BertTokenizer, T5Tokenizer],
) -> List[ndarray]:
    seq_lens = [len(seq) for seq in batch]
    # Remove rare amino acids
    batch = [re.sub(r"[UZOB]", "X", sequence) for sequence in batch]
    # Every amino acid is a "word"
    batch = [" ".join(list(seq)) for seq in batch]

    ids = tokenizer.batch_encode_plus(batch, add_special_tokens=True, padding="longest")

    tokenized_sequences = torch.tensor(ids["input_ids"]).to(model.device)
    attention_mask = torch.tensor(ids["attention_mask"]).to(model.device)

    with torch.no_grad():
        embeddings = model(input_ids=tokenized_sequences, attention_mask=attention_mask)

    embeddings = embeddings[0].cpu().numpy()

    # Cut each embedding back to the original sequence length (drops the padding)
    trimmed = []
    for embedding, seq_len in zip(embeddings, seq_lens):
        trimmed.append(embedding[:seq_len])
    return trimmed


def main():
    device = torch.device("cuda")
    bert = (BertModel, BertTokenizer, "Rostlab/prot_bert_bfd")
    t5 = (T5EncoderModel, T5Tokenizer, "Rostlab/prot_t5_xl_uniref50")

    results = []
    for model_class, tokenizer_class, model_name in [bert, t5]:
        for half in [True, False]:
            print(f"{model_name} half={half}")

            model = model_class.from_pretrained(model_name)

            if half:
                # This is the line that makes the difference
                model = model.half()

            model = model.to(device).eval()
            tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=False)

            embeddings_single_sequence = [
                embed_batch([sequence], model, tokenizer)[0] for sequence in sequences
            ]
            embeddings_batched = embed_batch(sequences, model, tokenizer)
            results.append(
                (model_name, half, embeddings_single_sequence, embeddings_batched)
            )

    for (model_name, half, embeddings_single_sequence, embeddings_batched) in results:
        print(f"{model_name} half={half}")
        for a, b in zip(embeddings_single_sequence, embeddings_batched):
            print(
                f"max: {numpy.max(numpy.abs(a - b)):<25} mean: {numpy.mean(a - b):<25}"
            )


if __name__ == "__main__":
    main()

(It also complains that we're not using the decoder weights, but this is expected.)

The issue "T5Model in fp16 still yield nan with more complex examples" (huggingface/transformers#4586) discusses NaN values in T5 fp16, but otherwise I couldn't find anything similar.
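
One way to judge whether the batched fp16 embeddings are still usable downstream would be to compare per-position cosine similarities between the single-sequence and the batched embeddings instead of raw differences; a sketch reusing embeddings_single_sequence and embeddings_batched from the script above:

import numpy

def min_cosine_similarity(a: numpy.ndarray, b: numpy.ndarray) -> float:
    """Lowest per-position cosine similarity between two (length, dim) embeddings."""
    a = a.astype(numpy.float64)
    b = b.astype(numpy.float64)
    dots = numpy.sum(a * b, axis=1)
    norms = numpy.linalg.norm(a, axis=1) * numpy.linalg.norm(b, axis=1)
    return float(numpy.min(dots / norms))

# e.g. for the first sequence:
# print(min_cosine_similarity(embeddings_single_sequence[0], embeddings_batched[0]))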

Environment info

  • transformers version: 4.5.1
  • Platform: Linux-4.15.0-126-generic-x86_64-with-glibc2.10 (ubuntu 18.04)
  • Python version: 3.8.0
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Nvidia Quadro RTX 8000
  • Using distributed or parallel set-up in script?: no