Flax - core dump when starting training

I'm trying to follow the instructions for training a RoBERTa-base MLM model, as described here:

Everything is easy to follow until the start of training, which immediately ends in a core dump with this error message:

tcmalloc: large alloc 500236124160 bytes == (nil) @  0x7f51b5df3680 0x7f51b5e13ff4 0x7f51b590a309 0x7f51b590bfb9 0x7f51b590c056 0x7f4e5cc6a659 0x7f4e526a0954 0x7f51b5fe7b8a 0x7f51b5fe7c91 0x7f51b5d46915 0x7f51b5fec0bf 0x7f51b5d468b8 0x7f51b5feb5fa 0x7f51b5bbb34c 0x7f51b5d468b8 0x7f51b5d46983 0x7f51b5bbbb59 0x7f51b5bbb3da 0x67299f 0x682dcb 0x684321 0x5c3cb0 0x5f257d 0x56fcb6 0x56822a 0x5f6033 0x56ef97 0x5f5e56 0x56a136 0x5f5e56 0x569f5e
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc
https://symbolize.stripped_domain/r/?trace=7f51b5c2918b,7f51b5c2920f&map=
*** SIGABRT received by PID 14088 (TID 14088) on cpu 95 from PID 14088; stack trace: ***
PC: @     0x7f51b5c2918b  (unknown)  raise
    @     0x7f4f86e6d800        976  (unknown)
    @     0x7f51b5c29210  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f51b5c2918b,7f4f86e6d7ff,7f51b5c2920f&map=2a762cd764e70bc90ae4c7f9747c08d7:7f4f79f2b000-7f4f871ac280
E0628 16:55:18.669807   14088 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
E0628 16:55:18.669833   14088 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
E0628 16:55:18.669843   14088 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0628 16:55:18.669852   14088 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
E0628 16:55:18.669864   14088 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0628 16:55:18.669874   14088 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0628 16:55:18.669881   14088 coredump_hook.cc:525] RAW: Discarding core.
E0628 16:55:18.673655   14088 process_state.cc:771] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

Any ideas about what is causing this?

This looks like an out-of-memory (OOM) error on the TPU cores; try smaller batch sizes (1 for starters) or reduce the data/model size.
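For scale, the single allocation that failed in the log (the first `tcmalloc: large alloc` line) can be converted to familiar units. Note that tcmalloc is a CPU heap allocator, so this particular request was for host memory. The helper name `describe_bytes` is just an illustration.

```python
REQUESTED_BYTES = 500_236_124_160  # taken from the first line of the tcmalloc log


def describe_bytes(n: int) -> str:
    """Format a byte count in decimal (GB) and binary (GiB) units."""
    return f"{n / 1e9:.2f} GB ({n / 2**30:.2f} GiB)"


print(describe_bytes(REQUESTED_BYTES))
```

That is roughly half a terabyte requested in one call, which is far more than a RoBERTa-base training step should need.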

Thanks for the feedback. The current setup is a RoBERTa-base model (config loaded from Hugging Face), training on the Norwegian OSCAR dataset from Hugging Face, with everything done according to the tutorial. This should run fine with a batch size of 128. I also tried setting it to 1 and still get exactly the same error. The error happens almost instantly after starting the run_mlm_flax.py script, which I find a bit strange.
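To put the batch-size argument in numbers, the raw input for one training step can be estimated. This is a rough sketch under stated assumptions: the global batch is the per-device batch times the number of devices (8 on a v3-8), sequences are 512 tokens long, and each example carries three int32 arrays (for instance input ids, attention mask and labels). `input_batch_bytes` is a hypothetical helper, not part of run_mlm_flax.py, and it ignores activations and model state.

```python
def input_batch_bytes(per_device_batch: int, n_devices: int, seq_len: int,
                      arrays_per_example: int = 3, bytes_per_token: int = 4) -> int:
    """Bytes needed for one global batch of integer token arrays.

    Assumption: every array has shape (global_batch, seq_len) with 4-byte ints.
    """
    global_batch = per_device_batch * n_devices
    return global_batch * seq_len * arrays_per_example * bytes_per_token


for bs in (128, 1):
    size = input_batch_bytes(bs, n_devices=8, seq_len=512)
    print(f"per-device batch {bs:>3}: {size / 2**20:.2f} MiB of input per step")
```

Even at a per-device batch of 128 this is a few MiB per step, many orders of magnitude below the ~500 GB request in the log, so the input batch alone cannot account for that allocation.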

This is a TPU VM architecture with a v3-8 running v2-alpha. I just tried rebooting the TPU, with the same result. A basic JAX/TPU test, i.e. “import jax;jax.device_count();jax.numpy.add(1,1)”, works fine, showing that the TPU is available.
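The one-liner sanity check above can be kept as a small script to rerun after reboots or reinstalls. This is a sketch; `tpu_sanity_check` is a hypothetical helper name. On a working TPU VM `jax.default_backend()` should report `"tpu"`, and on a v3-8 `jax.device_count()` is expected to be 8.

```python
import jax
import jax.numpy as jnp


def tpu_sanity_check():
    """Return (backend name, device count, result of 1 + 1 on the default device)."""
    backend = jax.default_backend()  # "tpu" on a working TPU VM, "cpu" if no accelerator is found
    n_devices = jax.device_count()   # number of devices JAX can see
    result = int(jnp.add(1, 1))      # forces a tiny computation on the default device
    return backend, n_devices, result


if __name__ == "__main__":
    backend, n_devices, result = tpu_sanity_check()
    print(f"backend={backend} devices={n_devices} 1+1={result}")
```

Passing this check only shows that JAX can see and use the TPU; it does not exercise the memory paths involved in the crash above.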

This is the config.json:

{
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
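As a rough check on whether the model itself could explain an allocation of that size, the parameter count can be estimated from the config above. This sketch assumes a standard BERT-style layer layout (biases on every dense layer, LayerNorm with scale and bias) and an MLM head whose decoder weight is tied to the word embeddings, so only the decoder bias is counted separately. It ignores optimizer state and activations, and `roberta_mlm_param_count` is a hypothetical helper.

```python
def roberta_mlm_param_count(cfg: dict) -> int:
    """Estimate parameters of a BERT-style encoder plus a tied MLM head."""
    h = cfg["hidden_size"]
    ff = cfg["intermediate_size"]
    v = cfg["vocab_size"]
    ln = 2 * h  # LayerNorm scale + bias

    embeddings = (v + cfg["max_position_embeddings"] + cfg["type_vocab_size"]) * h + ln
    attention = 4 * (h * h + h) + ln                # Q, K, V and output projections
    feed_forward = (h * ff + ff) + (ff * h + h) + ln  # intermediate + output dense
    encoder = cfg["num_hidden_layers"] * (attention + feed_forward)
    mlm_head = (h * h + h) + ln + v                 # dense + LayerNorm + decoder bias
    return embeddings + encoder + mlm_head


# Relevant fields copied from the config.json above.
config = {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "max_position_embeddings": 514,
    "num_hidden_layers": 12,
    "type_vocab_size": 1,
    "vocab_size": 50265,
}

n_params = roberta_mlm_param_count(config)
fp32_bytes = n_params * 4
print(f"{n_params:,} parameters, about {fp32_bytes / 1e9:.2f} GB in float32")
print(f"failed request / fp32 weights: about {500_236_124_160 / fp32_bytes:.0f}x")
```

Under these assumptions the model comes out at roughly 125M parameters, about 0.5 GB of float32 weights, so the ~500 GB request in the log is on the order of a thousand times the size of the weights.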

Sorry then, I have no idea what the bug with HF is :sweat_smile: I had the same problem and resolved it by reducing the memory usage of the model, cropping some of my data. No idea about your case, though.

I had this problem too. What helped me was

Also, if it’s relevant, I’ve written down this list of things I did that allowed me to train directly on the TPU VM: