🤗 Transformers with the Trainer API on TPU VMs and TPU Pods

Hi everyone,

I am trying to achieve the following objectives:

  • Running existing training scripts that use :hugs:Transformers with the Trainer API on a TPU VM (v2-8 or v3-8)
  • Running the same setup on a TPU Pod (v2-32)

Training with the Trainer API on a TPU VM

The existing training script already trains on Nvidia A100 GPUs using TrainingArguments and the Trainer API. To adapt it to a TPU VM, I did the following:

  1. Additionally install PyTorch and PyTorch/XLA

pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
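
For reference, a minimal snippet like the one below can be used to verify that the install actually sees the TPU before launching the real script (the file name check_tpu.py is just a placeholder):

# check_tpu.py -- minimal sanity check that PyTorch/XLA can see the TPU.
# Run with: PJRT_DEVICE=TPU python check_tpu.py
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # first XLA device, e.g. xla:0
t = torch.randn(2, 2, device=device)  # allocate a tensor on the TPU
print(device, t.sum().item())         # .item() forces execution and transfer back to host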

  2. Update the existing training code: add the tpu_num_cores argument and move the training logic inside an _mp_fn function, as below
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank, flags):
        training_args = TrainingArguments(
            output_dir=model_dir,
            learning_rate=1e-5, #The initial learning rate for AdamW optimizer. step size taken during training to update the model’s weights
            num_train_epochs=1, #default 3
            per_device_train_batch_size=8, #The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
            per_device_eval_batch_size=8, #The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
            gradient_accumulation_steps=1, #default 1 Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
            save_total_limit=2, # If a value is passed, will limit the total amount of checkpoints. For example, for save_total_limit=5 and load_best_model_at_end, the four last checkpoints will always be retained alongside the best model. When save_total_limit=1 and load_best_model_at_end, it is possible that two checkpoints are saved: the last one and the best one (if they are different).
            evaluation_strategy="no", #no evaluation during training (eval_steps below is then unused)
            eval_steps=100,#Number of update steps between two evaluations
            save_strategy="no",#no checkpointing during training (save_steps would control the interval otherwise)
            logging_steps=100, #Number of update steps between two logs
            remove_unused_columns=False, #Whether or not to automatically remove the columns unused by the model forward method.
            push_to_hub=False, #Whether or not to push the model to the Hub every time the model is saved
            label_names=["labels"], # The list of keys in your dictionary of inputs that correspond to the labels.
            load_best_model_at_end=True,
            dataloader_num_workers=12, #as suggested from the system Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process.
            report_to="tensorboard",
            logging_dir=logging_dir,#tensorboard log directory
            tpu_num_cores=8, # define number of TPU cores
            dataloader_drop_last=True #Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) or not
        )
        if loader_args["model"] == "BERT":
            model = ...  # custom model definition for BERT (omitted here)
        elif loader_args["model"] == "GPT2":
            model = ...  # custom model definition for GPT-2 (omitted here)
        else:
            print("Wrong model specified. Exiting.")
            exit(1)

        # Create Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=test_ds,
            data_collator = DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics
        )
        trainer.train()
   
# Call trainer
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
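
One thing I am unsure about: the PJRT examples in the PyTorch/XLA documentation seem to leave nprocs unset and let the runtime decide how many processes to start, so an alternative spawn call might look like the sketch below (this is just my understanding, not something I have verified):

# Alternative spawn (sketch): let the PJRT runtime pick the process count
# instead of hard-coding nprocs=8.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == "__main__":
    FLAGS = {}
    xmp.spawn(_mp_fn, args=(FLAGS,))  # nprocs defaults to None, i.e. all local devices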

I also updated the tokenizer padding to max_length:

    tokenizer = create_tokenizer(data, max_seq_len=1024, exp_id="Debug")
    tokenizer.padding = "max_length"

where the create_tokenizer function is defined as:

def create_tokenizer(data, max_seq_len, exp_id):
    print("Training Tokenizer...")
    tokenizer = Tokenizer(models.WordLevel())  # simplest tokenizer model: splits raw text into words and maps each word to an id
    tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
    trainer = trainers.WordLevelTrainer(  # trainer capable of training a WordLevel model
        special_tokens=[
            "<unk>",
            "<pad>",
            "<MASK>"
        ]
    )
    # pdb.set_trace()
    seqs_lengths = [len(v.split(" ")) for v in data["seq"].values]
    tokenizer.train_from_iterator(data["seq"].values, trainer=trainer)
    print("Vocabulary Size: {}".format(len(tokenizer.get_vocab())))  # total number of words in the vocabulary
    tokenizer.save("data/{}-tokenizer.json".format(exp_id))
    tokenizer = PreTrainedTokenizerFast(tokenizer_file="data/{}-tokenizer.json".format(exp_id))
    tokenizer.pad_token = '<pad>'
    tokenizer.model_max_length = max_seq_len
    return tokenizer
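
Related to the padding change above: I am not certain that setting tokenizer.padding has any effect on the Trainer, so an alternative I am considering is padding at the collator level to a fixed length (my understanding is that XLA prefers static shapes). A sketch, reusing the max_seq_len=1024 from above:

from transformers import DataCollatorWithPadding

# Sketch: pad every batch to the same fixed length so XLA sees static shapes.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding="max_length",  # always pad to max_length instead of the longest sample in the batch
    max_length=1024,       # matches the max_seq_len used for the tokenizer above
)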

  3. Run the training script (with a dataset size of 1000)

PJRT_DEVICE=TPU python trainer.py 1000

Sometimes it manages to train (around 8 minutes for 1 epoch), but sometimes it crashes with the following logs:

https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map= 
*** SIGABRT received by PID 71179 (TID 71179) on cpu 47 from PID 71179; stack trace: ***
PC: @     0x7f4610c969fc  (unknown)  pthread_kill
    @     0x7f438ebaa53a       1152  (unknown)
    @     0x7f4610c42520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0 
E1218 18:27:36.783398   71179 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.783418   71179 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.783425   71179 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.783432   71179 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.783459   71179 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.783467   71179 coredump_hook.cc:603] RAW: Dumping core locally.
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map= 
*** SIGABRT received by PID 71178 (TID 71178) on cpu 72 from PID 71178; stack trace: ***
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map= 
*** SIGABRT received by PID 71180 (TID 71180) on cpu 46 from PID 71180; stack trace: ***
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map= 
*** SIGABRT received by PID 71181 (TID 71181) on cpu 76 from PID 71181; stack trace: ***
PC: @     0x7f4610c969fc  (unknown)  pthread_kill
    @     0x7f438ebaa53a       1152  (unknown)
PC: @     0x7f4610c969fc  (unknown)  pthread_kill
    @     0x7f438ebaa53a       1152  (unknown)
    @     0x7f4610c42520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0 
E1218 18:27:36.792775   71180 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.792791   71180 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.792798   71180 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.792805   71180 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.792826   71180 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.792834   71180 coredump_hook.cc:603] RAW: Dumping core locally.
    @     0x7f4610c42520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0 
E1218 18:27:36.792896   71178 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.792911   71178 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.792918   71178 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.792925   71178 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.792948   71178 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.792956   71178 coredump_hook.cc:603] RAW: Dumping core locally.
PC: @     0x7f4610c969fc  (unknown)  pthread_kill
    @     0x7f438ebaa53a       1152  (unknown)
    @     0x7f4610c42520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0 
E1218 18:27:36.793691   71181 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.793706   71181 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.793713   71181 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.793720   71181 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.793741   71181 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.793749   71181 coredump_hook.cc:603] RAW: Dumping core locally.
E1218 18:27:36.975384   71179 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:36.980584   71180 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:36.981091   71181 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:37.091601   71178 process_state.cc:783] RAW: Raising signal 6 with default behavior
Traceback (most recent call last):
  File "/home/anh-dung.le/test-health-llm/trainer.py", line 148, in <module>
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 202, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 159, in run_multiprocess
    replica_results = list(
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 160, in <genexpr>
    itertools.chain.from_iterable(
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/process.py", line 570, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

I would like to know whether my approach above for training with the Trainer API on a TPU VM is appropriate, or whether there are missing pieces I should add (performance-wise, it doesn't feel great compared to the 4 A100s I was testing on).

If I understand correctly, to benchmark performance I would have to install and run a profiling tool for PyTorch/XLA and also update the code above. Is that right?
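
In case it helps frame that question: from the PyTorch/XLA docs, my understanding is that profiling involves starting a profiler server inside the training process and reading the client-side metrics report, roughly like this (a sketch only; the port number is arbitrary):

import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # profiler server, so a TensorBoard trace can be captured later

# ... run some training steps ...

print(met.metrics_report())     # compile/execute counters, useful for spotting recompilations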

Training with the Trainer API on TPU Pods

In addition to this, I would like to test performance by running the training on TPU Pods (v2-32), following this tutorial: Run PyTorch code on TPU Pod slices  |  Google Cloud. From what I understand, we have to include --worker=all in the gcloud command, something like this:

gcloud compute tpus tpu-vm ssh $vmName --zone $location --worker=all --command="PJRT_DEVICE=TPU python3 trainer.py $dataset_size > $logs_output" 

Is this correct? If so, how does a TPU Pod split the training work across the TPU VMs and gather the results afterwards? Also, which value of tpu_num_cores should I use in TrainingArguments: 8 or 32?
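
To check my understanding of how the work gets split, I was planning to print each process's view of the pod from inside _mp_fn, something like this (a sketch, not yet tested on the v2-32):

import torch_xla.core.xla_model as xm

def _mp_fn(rank, flags):
    # On a v2-32 slice I would expect a world size of 32 spread across the
    # 4 worker VMs, with a unique global ordinal per process/device.
    print(f"ordinal={xm.get_ordinal()} world_size={xm.xrt_world_size()}")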

Thanks a lot for your help, and sorry for the long post; I'm pretty new to all of this (and very excited about it :slight_smile: )
Cheers