Hi everyone,
I am trying to achieve the following objectives:
- Running existing training scripts that use Transformers with the Trainer API on a TPU VM (v2-8 or v3-8)
- Running the same set-up on TPU Pods (v2-32)
Training with the Trainer API on a TPU VM
The existing training script can already train on an Nvidia A100 using TrainingArguments and the Trainer API. To adapt it to the TPU VM, I did the following:
- Additionally install PyTorch and PyTorch/XLA (with a quick sanity check right after the install command):
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
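As a quick sanity check of the installation (before touching the training script), I run a tiny snippet like the following on the TPU VM; this is just my own check, assuming PJRT_DEVICE=TPU is set in the environment:
import torch
import torch_xla.core.xla_model as xm

# Grab the XLA device and run a small op to confirm the TPU is reachable
device = xm.xla_device()
print("XLA device:", device)
t = torch.randn(2, 2, device=device)
print(t @ t)
xm.mark_step()  # explicitly flush any pending XLA graph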
- Update the existing training code: add the tpu_num_cores argument to TrainingArguments and move the training code inside an _mp_fn function, as shown below
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank, flags):
    training_args = TrainingArguments(
        output_dir=model_dir,
        learning_rate=1e-5,  # initial learning rate for the AdamW optimizer
        num_train_epochs=1,  # default 3
        per_device_train_batch_size=8,  # batch size per TPU core for training
        per_device_eval_batch_size=8,  # batch size per TPU core for evaluation
        gradient_accumulation_steps=1,  # number of update steps to accumulate gradients before a backward/update pass (default 1)
        save_total_limit=2,  # limit the total number of checkpoints kept on disk
        evaluation_strategy="no",  # evaluation is done (and logged) every eval_steps
        eval_steps=100,  # number of update steps between two evaluations
        save_strategy="no",  # save is done every save_steps
        logging_steps=100,  # number of update steps between two logs
        remove_unused_columns=False,  # do not automatically drop columns unused by the model forward method
        push_to_hub=False,  # do not push the model to the Hub when it is saved
        label_names=["labels"],  # keys in the input dict that correspond to the labels
        load_best_model_at_end=True,
        dataloader_num_workers=12,  # number of subprocesses for data loading (as suggested by the system)
        report_to="tensorboard",
        logging_dir=logging_dir,  # TensorBoard log directory
        tpu_num_cores=8,  # number of TPU cores
        dataloader_drop_last=True  # drop the last incomplete batch if the dataset size is not divisible by the batch size
    )
    if loader_args["model"] == "BERT":
        ...  # some custom model definition for BERT
    elif loader_args["model"] == "GPT2":
        ...  # some custom model definition for GPT2
    else:
        print("Wrong model specified. Exiting.")
        exit(1)
    # Create the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
    )
    trainer.train()

# Spawn the training processes
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
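One thing I am unsure about is the spawn call itself: if I read the PJRT migration notes correctly, nprocs is supposed to be left unset (or 1) so that one process per local TPU core is created automatically. My assumption of what it would look like (not verified on my side):
# Assumption based on my reading of the PJRT notes: leave nprocs unset so
# torch_xla spawns one process per local TPU core on its own.
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), start_method='fork')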
I also updated the tokenizer padding to max_length:
tokenizer = create_tokenizer(data, max_seq_len=1024, exp_id=Debug)
tokenizer.padding = "max_length"
where the create_tokenizer function is:
def create_tokenizer(data, max_seq_len, exp_id):
    print("Training Tokenizer...")
    tokenizer = Tokenizer(models.WordLevel())  # simplest tokenizer model: splits raw text into words/subwords and maps each token to an id
    tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
    trainer = trainers.WordLevelTrainer(  # trainer capable of training a WordLevel model
        special_tokens=[
            "<unk>",
            "<pad>",
            "<MASK>"
        ]
    )
    # pdb.set_trace()
    seqs_lengths = [len(v.split(" ")) for v in data["seq"].values]
    tokenizer.train_from_iterator(data["seq"].values, trainer=trainer)
    print("Vocabulary Size: {}".format(len(tokenizer.get_vocab())))  # total number of words
    tokenizer.save("data/{}-tokenizer.json".format(exp_id))
    tokenizer = PreTrainedTokenizerFast(tokenizer_file="data/{}-tokenizer.json".format(exp_id))
    tokenizer.pad_token = '<pad>'
    tokenizer.model_max_length = max_seq_len
    return tokenizer
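Related to the padding, since XLA recompiles whenever the input shapes change, I am also wondering whether the fixed-length padding should rather be passed through the collator. A sketch of what I mean (assuming DataCollatorWithPadding's padding/max_length arguments are the right place for this):
from transformers import DataCollatorWithPadding

# Pad every batch to the same fixed length so the XLA graph is compiled only once
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding="max_length",  # always pad to max_length instead of the longest item in the batch
    max_length=1024,       # matches the max_seq_len used for the tokenizer
)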
- Run the training script (with a dataset size of 1000):
PJRT_DEVICE=TPU python trainer.py 1000
Sometimes it manages to train (around 8 minutes for 1 epoch), but sometimes it crashes with the following logs:
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map=
*** SIGABRT received by PID 71179 (TID 71179) on cpu 47 from PID 71179; stack trace: ***
PC: @ 0x7f4610c969fc (unknown) pthread_kill
@ 0x7f438ebaa53a 1152 (unknown)
@ 0x7f4610c42520 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0
E1218 18:27:36.783398 71179 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.783418 71179 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.783425 71179 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.783432 71179 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.783459 71179 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.783467 71179 coredump_hook.cc:603] RAW: Dumping core locally.
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map=
*** SIGABRT received by PID 71178 (TID 71178) on cpu 72 from PID 71178; stack trace: ***
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map=
*** SIGABRT received by PID 71180 (TID 71180) on cpu 46 from PID 71180; stack trace: ***
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f4610c4251f&map=
*** SIGABRT received by PID 71181 (TID 71181) on cpu 76 from PID 71181; stack trace: ***
PC: @ 0x7f4610c969fc (unknown) pthread_kill
@ 0x7f438ebaa53a 1152 (unknown)
PC: @ 0x7f4610c969fc (unknown) pthread_kill
@ 0x7f438ebaa53a 1152 (unknown)
@ 0x7f4610c42520 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0
E1218 18:27:36.792775 71180 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.792791 71180 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.792798 71180 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.792805 71180 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.792826 71180 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.792834 71180 coredump_hook.cc:603] RAW: Dumping core locally.
@ 0x7f4610c42520 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0
E1218 18:27:36.792896 71178 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.792911 71178 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.792918 71178 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.792925 71178 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.792948 71178 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.792956 71178 coredump_hook.cc:603] RAW: Dumping core locally.
PC: @ 0x7f4610c969fc (unknown) pthread_kill
@ 0x7f438ebaa53a 1152 (unknown)
@ 0x7f4610c42520 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f4610c969fc,7f438ebaa539,7f4610c4251f&map=abbd016d9542b8098892badc0b19ea68:7f4381a00000-7f438edbecf0
E1218 18:27:36.793691 71181 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1218 18:27:36.793706 71181 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1218 18:27:36.793713 71181 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1218 18:27:36.793720 71181 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1218 18:27:36.793741 71181 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1218 18:27:36.793749 71181 coredump_hook.cc:603] RAW: Dumping core locally.
E1218 18:27:36.975384 71179 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:36.980584 71180 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:36.981091 71181 process_state.cc:783] RAW: Raising signal 6 with default behavior
E1218 18:27:37.091601 71178 process_state.cc:783] RAW: Raising signal 6 with default behavior
Traceback (most recent call last):
File "/home/anh-dung.le/test-health-llm/trainer.py", line 148, in <module>
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper
return fn(*args, **kwargs)
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 202, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper
return fn(*args, **kwargs)
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 159, in run_multiprocess
replica_results = list(
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 160, in <genexpr>
itertools.chain.from_iterable(
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/process.py", line 570, in _chain_from_iterable_of_lists
for element in iterable:
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/home/anh-dung.le/miniconda3/envs/health-llm/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
I would like to know whether my approach above for training with the Trainer API on a TPU VM is appropriate, or whether there are points I am missing (performance-wise, it does not feel great compared to the 4 A100s I was testing on).
If I understand correctly, in order to benchmark performance I also have to install and run the profiling tool for PyTorch/XLA and update the code above accordingly?
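For the profiling part, my understanding from the PyTorch/XLA docs is that it would look roughly like this (an untested sketch; the port and logdir are placeholders I picked myself):
import torch_xla.debug.profiler as xp

# Inside _mp_fn, before training starts: expose a profiler server in each process
server = xp.start_server(9012)

# From a second shell/script: capture a trace that can be viewed in TensorBoard
xp.trace('localhost:9012', logdir='profile_logs/', duration_ms=30000)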
Training with the Trainer API on TPU Pods
In addition, I would like to test performance when running the training on TPU Pods (v2-32) by following this tutorial: Run PyTorch code on TPU Pod slices | Google Cloud. From what I understand, we have to include --worker=all in the gcloud command, something like this:
gcloud compute tpus tpu-vm ssh $vmName --zone $location --worker=all --command="PJRT_DEVICE=TPU python3 trainer.py $dataset_size > $logs_output"
Is this correct? If so, how do TPU Pods split the training work across the TPU VMs and gather it all back together afterwards? Also, what is the correct value of tpu_num_cores in TrainingArguments: should I use 8 or 32?
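To try to figure out the 8 vs 32 question myself, I was planning to print the device counts that each worker actually sees, something like this (my own check, assuming these runtime helpers behave the way I think they do):
import torch_xla.core.xla_model as xm

# Inside _mp_fn on the pod, to compare the local vs global view of the devices
print("global ordinal:", xm.get_ordinal())       # rank of this process across the whole pod
print("local ordinal:", xm.get_local_ordinal())  # rank within this TPU VM worker
print("world size:", xm.xrt_world_size())        # total number of processes across the pod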
Thanks a lot for your help, and sorry for the long post; I'm pretty new to all of this (and very excited about it).
Cheers