TPU memory problem when saving a model checkpoint

Hi,

I’m running into a problem when saving model checkpoints during fine-tuning. Training itself runs fine, but as soon as a checkpoint step is reached and the model has to be saved, the TPU throws a memory error.

I’m fine-tuning mBART on a single v3-8 TPU on a GCP VM. To rule out host memory as the cause I tried a bigger VM with 52 GB of RAM; this didn’t help, which matches the error below: it reports running out of hbm, the TPU’s on-device memory, not host RAM.

Versions:
torch 1.11.0
torch-xla 1.11.0
transformers 4.18.0
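
(For reference, the versions above were read with a quick check along these lines:)

import torch, torch_xla, transformers
# Print the versions of the three libraries listed above
print(torch.__version__, torch_xla.__version__, transformers.__version__)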

The code that causes the error:

from datasets import load_dataset
from transformers import (
    MBartTokenizer,
    MBartForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

raw_dataset = load_dataset(path='parquet',
                           data_files={'train': ['viable_cases_chunk_1.parquet',
                                                 'viable_cases_chunk_2.parquet',
                                                 'viable_cases_chunk_3.parquet'],
                                       'test': ['viable_cases_chunk_4.parquet']})

# Load the tokenizer and model checkpoint
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

# tokenizer = AutoTokenizer.from_pretrained("yhavinga/t5-v1.1-base-dutch-cased")
# model = AutoModelForSeq2SeqLM.from_pretrained("yhavinga/t5-v1.1-base-dutch-cased")

def preprocess_function(examples):
    # Pad inputs and targets to fixed lengths so XLA sees static shapes
    model_inputs = tokenizer(examples["description"], max_length=1024,
                             padding="max_length", truncation=True)
    labels = tokenizer(examples["summary"], max_length=128,
                       padding="max_length", truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["attention_mask_labels"] = labels["attention_mask"]

    return model_inputs

# Tokenize the dataset
tokenized_dataset = raw_dataset.map(preprocess_function, batched=True)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    save_steps=100,
    num_train_epochs=3,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
)

trainer.train()

The resulting error:

***** Running training *****
  Num examples = 21159
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 7935
  1%|██▍                                                                                                                                                                                               | 100/7935 [01:42<2:13:25,  1.02s/it]Saving model checkpoint to ./results/checkpoint-100
Configuration saved in ./results/checkpoint-100/config.json
2022-04-07 12:11:55.164393: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] StackTrace:
2022-04-07 12:11:55.164475: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] *** Begin stack trace ***
2022-04-07 12:11:55.164488: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]        tensorflow::CurrentStackTrace()
2022-04-07 12:11:55.164498: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]        xla::util::ReportComputationError(tensorflow::Status const&, absl::lts_20211102::Span<xla::XlaComputation const* const>, absl::lts_20211102::Span<xla::Shape const* const>)
2022-04-07 12:11:55.164518: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]        xla::XrtComputationClient::CheckCompileStatus(tensorflow::Status const&, std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> > const&, xla::XrtComputationClient::SessionWork const&)
2022-04-07 12:11:55.164539: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164565: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]        xla::util::MultiWait::Complete(std::function<void ()> const&)
2022-04-07 12:11:55.164579: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164603: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164617: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164631: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]        clone
2022-04-07 12:11:55.164645: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] *** End stack trace ***
2022-04-07 12:11:55.164658: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164672: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] Status: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2022-04-07 12:11:55.164687: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] 2 root error(s) found.
2022-04-07 12:11:55.164701: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 17.66G of 15.98G hbm. Exceeded hbm capacity by 1.67G.
2022-04-07 12:11:55.164715: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164728: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] Total hbm usage >= 17.67G:
2022-04-07 12:11:55.164743: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     reserved         18.00M
2022-04-07 12:11:55.164757: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     program           8.55G
2022-04-07 12:11:55.164768: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     arguments         9.10G
2022-04-07 12:11:55.164781: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164794: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] Output size 2.28G; shares 2.28G with arguments.
2022-04-07 12:11:55.164808: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164822: E tensorflow/compiler/xla/xla_client/xla_util.cc:88] Program hbm requirement 8.55G:
2022-04-07 12:11:55.164835: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     global             4.0K
2022-04-07 12:11:55.164849: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     scoped           18.50M
2022-04-07 12:11:55.164863: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]     HLO temp          8.53G (100.0% utilization: Unpadded (8.41G) Padded (8.41G), 1.4% fragmentation (122.91M))
2022-04-07 12:11:55.164874: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164888: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   Largest program allocations in hbm:
2022-04-07 12:11:55.164901: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.164914: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   1. Size: 976.69M
2022-04-07 12:11:55.164925: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8,128,250027]{1,2,0:T(8,128)}
2022-04-07 12:11:55.164943: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 976.67M
2022-04-07 12:11:55.164954: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Extra memory due to padding: 20.0K (1.0x expansion)
2022-04-07 12:11:55.164969: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.7719 = (f32[8,128]{1,0:T(8,128)}, f32[8,128,250027]{1,2,0:T(8,128)}) fusion(f32[250027]{0:T(1024)} %reduce.2119, f32[8,128,1024]{2,1,0:T(8,128)} %get-tuple-element.2057, f32[8,128]{1,0:T(8,128)} %get-tuple-element.1897, f32[8,128]{1,0:T(8,128)} %ge...
2022-04-07 12:11:55.164985: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165000: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165011: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165023: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   2. Size: 976.67M
2022-04-07 12:11:55.165037: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8,128,250027]{1,0,2:T(8,128)}
2022-04-07 12:11:55.165048: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 976.67M
2022-04-07 12:11:55.165062: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.26.remat_compressed = f32[8,128,250027]{1,0,2:T(8,128)} copy(f32[8,128,250027]{1,2,0:T(8,128)} %get-tuple-element.2059)
2022-04-07 12:11:55.165074: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165088: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165101: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165112: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   3. Size: 512.00M
2022-04-07 12:11:55.165125: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165139: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165150: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.42.remat5 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9720, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9818, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9824), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165165: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165176: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165189: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165200: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   4. Size: 512.00M
2022-04-07 12:11:55.165210: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165224: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165235: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.43.remat5 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9718, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9772, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9778), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165251: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165262: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165275: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165285: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   5. Size: 512.00M
2022-04-07 12:11:55.165296: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165309: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165320: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.44.remat5 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9716, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9712, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9707), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165337: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165346: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165359: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165369: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   6. Size: 512.00M
2022-04-07 12:11:55.165391: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8,16,1024,1024]{2,3,1,0:T(8,128)}
2022-04-07 12:11:55.165400: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165411: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %broadcast.13109.remat9 = f32[8,16,1024,1024]{2,3,1,0:T(8,128)} broadcast(f32[8,1024,1024]{1,2,0:T(8,128)} %get-tuple-element.3652), dimensions={0,2,3}
2022-04-07 12:11:55.165426: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165439: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165450: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165463: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   7. Size: 512.00M
2022-04-07 12:11:55.165474: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165488: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165502: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.45.remat3 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9621, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.9591, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.5925), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165517: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165528: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165541: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165555: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   8. Size: 512.00M
2022-04-07 12:11:55.165566: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165579: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165593: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.7695 = (f32[128,1024]{1,0:T(8,128)}, f32[128,1024,1024]{1,2,0:T(8,128)}) fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9733, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6066, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6068), kind=kOut...
2022-04-07 12:11:55.165609: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165623: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165634: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165646: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   9. Size: 512.00M
2022-04-07 12:11:55.165660: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165671: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165685: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.36.remat3 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9732, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6053, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6055), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165701: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165712: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165724: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165738: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   10. Size: 512.00M
2022-04-07 12:11:55.165752: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,1024,1024]{1,2,0:T(8,128)}
2022-04-07 12:11:55.165766: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 512.00M
2022-04-07 12:11:55.165776: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.37.remat3 = f32[128,1024,1024]{1,2,0:T(8,128)} fusion(f32[128,1024,1024]{1,2,0:T(8,128)} %bitcast.9730, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6040, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6042), kind=kOutput, calls=%fused_computat...
2022-04-07 12:11:55.165792: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165801: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165814: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165827: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   11. Size: 128.00M
2022-04-07 12:11:55.165838: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8,1024,4096]{2,1,0:T(8,128)}
2022-04-07 12:11:55.165849: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 128.00M
2022-04-07 12:11:55.165862: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.180.remat4 = f32[8,1024,4096]{2,1,0:T(8,128)} fusion(f32[4096]{0:T(1024)} %p132.2204, f32[8,1024,1024]{2,1,0:T(8,128)} %fusion.499.remat2.1.remat3, f32[8,1024]{1,0:T(8,128)} %get-tuple-element.1808, f32[8,1024]{1,0:T(8,128)} %get-tuple-element.1834...
2022-04-07 12:11:55.165875: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165888: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165901: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.165912: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   12. Size: 64.00M
2022-04-07 12:11:55.165923: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,128,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.165941: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 64.00M
2022-04-07 12:11:55.165955: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.350.remat3 = f32[128,128,1024]{2,1,0:T(8,128)} fusion(f32[128,128,1024]{2,1,0:T(8,128)} %bitcast.10604, bf16[128,128,64]{1,2,0:T(8,128)(2,1)} %bitcast.705, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6198), kind=kOutput, calls=%fused_computatio...
2022-04-07 12:11:55.165972: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.165985: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.165999: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166010: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   13. Size: 64.00M
2022-04-07 12:11:55.166023: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,128,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.166038: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 64.00M
2022-04-07 12:11:55.166052: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.356.remat3 = f32[128,128,1024]{2,1,0:T(8,128)} fusion(f32[128,128,1024]{2,1,0:T(8,128)} %bitcast.10262, bf16[128,128,64]{1,2,0:T(8,128)(2,1)} %bitcast.465, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6138), kind=kOutput, calls=%fused_computatio...
2022-04-07 12:11:55.166067: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166079: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166092: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166105: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   14. Size: 64.00M
2022-04-07 12:11:55.166116: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,128,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.166130: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 64.00M
2022-04-07 12:11:55.166141: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.357.remat3 = f32[128,128,1024]{2,1,0:T(8,128)} fusion(f32[128,128,1024]{2,1,0:T(8,128)} %bitcast.10143, bf16[128,128,64]{1,2,0:T(8,128)(2,1)} %bitcast.425, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6128), kind=kOutput, calls=%fused_computatio...
2022-04-07 12:11:55.166156: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166167: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166180: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166191: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   15. Size: 64.00M
2022-04-07 12:11:55.166204: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,128,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.166218: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 64.00M
2022-04-07 12:11:55.166229: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.358.remat3 = f32[128,128,1024]{2,1,0:T(8,128)} fusion(f32[128,128,1024]{2,1,0:T(8,128)} %bitcast.10046, bf16[128,128,64]{1,2,0:T(8,128)(2,1)} %bitcast.385, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6118), kind=kOutput, calls=%fused_computatio...
2022-04-07 12:11:55.166245: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166256: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166269: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166283: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   16. Size: 64.00M
2022-04-07 12:11:55.166293: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[128,128,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.166307: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 64.00M
2022-04-07 12:11:55.166321: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.360.remat3 = f32[128,128,1024]{2,1,0:T(8,128)} fusion(f32[128,128,1024]{2,1,0:T(8,128)} %bitcast.9992, bf16[128,128,64]{1,2,0:T(8,128)(2,1)} %bitcast.35, bf16[128,1024,64]{1,2,0:T(8,128)(2,1)} %bitcast.6085), kind=kOutput, calls=%fused_computation....
2022-04-07 12:11:55.166334: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166347: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166358: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166371: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   17. Size: 32.00M
2022-04-07 12:11:55.166382: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8,1024,1024]{2,1,0:T(8,128)}
2022-04-07 12:11:55.166395: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 32.00M
2022-04-07 12:11:55.166406: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.6775 = (f32[8,1024]{1,0:T(8,128)}, f32[8,1024]{1,0:T(8,128)}, f32[8,1024,1024]{2,1,0:T(8,128)}) fusion(f32[8,1024,1024]{2,1,0:T(8,128)} %fusion.484.remat2, f32[]{:T(256)S(6)} %divide.983, u8[1024,1024]{1,0:T(8,128)(4,1)} %fusion.1121, f32[1024]{0:T...
2022-04-07 12:11:55.166422: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166433: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166443: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166457: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   18. Size: 32.00M
2022-04-07 12:11:55.166468: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape: f32[8192,1024]{1,0:T(8,128)}
2022-04-07 12:11:55.166482: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Unpadded size: 32.00M
2022-04-07 12:11:55.166496: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      XLA label: %fusion.2167 = (f32[8192]{0:T(1024)}, f32[8192]{0:T(1024)}, f32[8192,1024]{1,0:T(8,128)}) fusion(f32[8192,1024]{1,0:T(8,128)} %bitcast.9556, f32[8192,1024]{1,0:T(8,128)} %fusion.4.remat2, f32[]{:T(256)S(6)} %copy.4485), kind=kLoop, calls=%fused_computation...
2022-04-07 12:11:55.166509: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Allocation type: HLO temp
2022-04-07 12:11:55.166523: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      ==========================
2022-04-07 12:11:55.166536: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]
2022-04-07 12:11:55.166547: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]   19. Size: 32.00M
2022-04-07 12:11:55.166560: E tensorflow/compiler/xla/xla_client/xla_util.cc:88]      Shape:
Traceback (most recent call last):
  File "legalsum.py", line 115, in <module>
    main()
  File "legalsum.py", line 111, in main
    trainer.train()
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1497, in train
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1628, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1700, in _save_checkpoint
    self.save_model(output_dir, _internal_call=True)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 2087, in save_model
    self._save_tpu(output_dir)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 2158, in _save_tpu
    self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
  File "/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py", line 1375, in save_pretrained
    save_function(shard, os.path.join(save_directory, shard_file))
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 885, in save
    cpu_data = _maybe_convert_to_cpu(data, convert=should_write_data)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 903, in _maybe_convert_to_cpu
    return ToXlaTensorArena(convert_fn, select_fn).transform(data)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 413, in transform
    self._convert()
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 385, in _convert
    self._converted_tensors = self._convert_fn(self._tensors)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 895, in convert_fn
    tensors, devices=[], wait=True, sync_xla_data=True)
RuntimeError: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 17.66G of 15.98G hbm. Exceeded hbm capacity by 1.67G.

[...snip — the rest of the RuntimeError message repeats the same hbm usage totals and "Largest program allocations in hbm" list already shown in the stderr log above...]

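From the traceback, the save fails inside xm.save → _maybe_convert_to_cpu, i.e. while torch_xla materializes the weights so they can be copied to host memory. As a possible workaround I’m considering saving from an explicit CPU copy outside the Trainer, along the lines of the untested sketch below (the output path is made up, and I’m not sure this actually avoids the HBM spike, since moving tensors to CPU also forces the pending graph to run):

import os
import torch
import torch_xla.core.xla_model as xm

# Untested idea: flush the pending lazy-execution graph first, then copy
# the weights to host memory and save with plain torch.save, bypassing
# Trainer._save_tpu / xm.save.
xm.mark_step()

output_dir = "./results/manual-checkpoint"  # hypothetical path
os.makedirs(output_dir, exist_ok=True)
cpu_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}
torch.save(cpu_state_dict, os.path.join(output_dir, "pytorch_model.bin"))

The idea is just to decouple the save from the Trainer, so I can at least tell whether xm.save itself is what pushes HBM over the limit.
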
Any help is appreciated!