Hey everyone,
I'm trying to fine-tune a TFBertForSequenceClassification model on a TPU using TensorFlow.
I run the model on GCP with a TPU v3-8 and fine-tune a pretrained bert-large model.
I want to use a batch size of 128 texts * 512 tokens, but I get an out-of-memory exception on the TPU (it works with a batch size of 64).
I load the data from TFRecord files stored in Cloud Storage.
Does it make sense that I can't use a batch size of 128?
Any tips on how to use less memory?
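For reference, here is how I understand the per-core numbers on the v3-8 (8 cores with 16 GB of HBM each); these figures are my assumption rather than something I measured:

num_replicas = 8                                          # cores on a v3-8
global_batch_size = 128
per_replica_batch = global_batch_size // num_replicas     # 16 sequences of 512 tokens per core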
Thanks in advance,
Tomer

Here is my code:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import HfArgumentParser, TFTrainingArguments
from dataclasses import dataclass, field
import json
import os
# region Helper classes
class SavePretrainedCallback(tf.keras.callbacks.Callback):
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
# that saves the model with this method after each epoch.
def __init__(self, output_dir, **kwargs):
super().__init__()
self.output_dir = output_dir
def on_epoch_end(self, epoch, logs=None):
# self.model.save_weights("{}/tpu-model.h5".format(self.output_dir))
self.model.save_pretrained(self.output_dir)
FEATURE_DESC = {
'input_ids': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
'attention_mask': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
'token_type_ids': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
'labels': tf.io.FixedLenFeature([], tf.int64)
}
def _parse_function(example_proto):
# Parse the input `tf.train.Example` proto using the dictionary above.
example = tf.io.parse_single_example(example_proto, FEATURE_DESC)
input_ids = tf.cast(example['input_ids'], tf.int32)
attention_mask = tf.cast(example['attention_mask'], tf.int32)
token_type_ids = tf.cast(example['token_type_ids'], tf.int32)
labels = tf.cast(example['labels'], tf.int32)
return {'input_ids': input_ids, 'attention_mask': attention_mask,
'token_type_ids': token_type_ids}, labels
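# For reference, a simplified sketch of how records matching FEATURE_DESC could be written.
# This is not part of the training script; `tokenized` is assumed to be the dict returned by
# the tokenizer for a single text (plain Python lists of ints) and `label` an int.
def serialize_example(tokenized, label):
    feature = {
        'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=tokenized['input_ids'])),
        'attention_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=tokenized['attention_mask'])),
        'token_type_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=tokenized['token_type_ids'])),
        'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()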
def create_training_dataset(data_args, training_args):
train_files = tf.io.gfile.glob(data_args.train_files)
raw_train_data = tf.data.TFRecordDataset(train_files, num_parallel_reads=tf.data.AUTOTUNE)
parsed_train_dataset = raw_train_data.map(_parse_function)
train_dataset = parsed_train_dataset
# train_dataset = parsed_train_dataset.shuffle(buffer_size=len(parsed_train_dataset))
if training_args.max_steps > -1:
max_train_size = training_args.max_steps * training_args.train_batch_size
print("number of train samples:", max_train_size)
train_dataset = train_dataset.take(max_train_size)
train_dataset = train_dataset.batch(training_args.train_batch_size)
return train_dataset
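# Note: a variant of the batching above that I am considering for the TPU. My assumption is that
# the TPU wants static batch shapes (drop_remainder=True) and that prefetch helps overlap the
# input pipeline with compute:
#
#     train_dataset = train_dataset.batch(training_args.train_batch_size, drop_remainder=True)
#     train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)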
def create_model(model_name):
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, from_pt=True)
return model
@dataclass
class DataArguments:
train_files: str = "records_train/*.tfrecords"
test_files: str = "/records_dev/*.tfrecords"
if __name__ == "__main__":
parser = HfArgumentParser((DataArguments, TFTrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses()
# Create a description of the features.
    model_name = "..."  # name of the BERT model on the Hub (omitted here)
if not os.path.exists(training_args.output_dir):
os.mkdir(training_args.output_dir)
with training_args.strategy.scope():
# validation_dataset = parsed_train_dataset.skip(max_train_size)
model = create_model(model_name)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = ['accuracy']
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=training_args.learning_rate), loss=loss,
metrics=metrics)
if training_args.do_train:
train_dataset = create_training_dataset(data_args, training_args)
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
history = model.fit(train_dataset, epochs=int(training_args.num_train_epochs), verbose=1,
callbacks=callbacks)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(training_args.output_dir)
print(history.history)
if training_args.do_eval:
test_files = tf.io.gfile.glob(data_args.test_files)
raw_test_data = tf.data.TFRecordDataset(test_files, num_parallel_reads=tf.data.AUTOTUNE)
test_dataset = raw_test_data.map(_parse_function)
print("eval model!!!")
if training_args.eval_steps > -1:
max_eval_samples = training_args.eval_steps * training_args.eval_batch_size
test_dataset = test_dataset.take(max_eval_samples)
res = model.evaluate(test_dataset.batch(training_args.eval_batch_size), return_dict=True)
print("eval res:", res)
with open("{}/{}_eval_res.json".format(training_args.output_dir, training_args.run_name), 'w') as f:
json.dump(res, f)
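One thing I have not tried yet, and I am not sure it is the right fix, is bfloat16 mixed precision, which as far as I understand should reduce activation memory on the TPU. A minimal sketch (the policy would have to be set before create_model() and compile()):

import tensorflow as tf
# assumption: setting the global policy before the model is built makes the layers compute in bfloat16
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")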