I am training a TFT model from PyTorch Forecasting, but I am facing memory issues: memory usage rises at every batch iteration until the end of the first epoch and then stays at that level. I tried different batch sizes, model parameters, and smaller datasets, but nothing changed. I am training on CPU on Google Colab with 51 GB of memory, yet it crashes before the second epoch. Here is my config:
# Build the training TimeSeriesDataSet: a 120-step encoder window that
# predicts one step ahead, with one series per "group_id".
encoder_window = 120
prediction_window = 1

training = TimeSeriesDataSet(
    train,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=encoder_window,
    max_prediction_length=prediction_window,
    # No static categoricals / unknown categoricals in this setup.
    static_categoricals=[],
    time_varying_unknown_categoricals=[],
    time_varying_known_categoricals=time_category_columns,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=time_reals,
    # Extra engineered features the dataset adds automatically.
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)
# Validation set: reuse the training dataset's parameters/encoders and
# (predict=True) forecast the last max_prediction_length points of each series.
validation = TimeSeriesDataSet.from_dataset(training, train, predict=True, stop_randomization=True)

# DataLoaders for the model.
batch_size = 64  # set this between 32 to 128
val_batch_size = batch_size * 10  # validation can use larger batches (no gradients)
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=val_batch_size, num_workers=0)
# Callbacks: early stopping and learning-rate logging.
# NOTE(review): monitoring "train_loss" means early stopping never looks at
# validation performance — consider monitor="val_loss" instead.
early_stop_callback = EarlyStopping(monitor="train_loss", min_delta=1e-2, patience=PATIENCE, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    enable_model_summary=True,
    gradient_clip_val=0.25,
    limit_train_batches=1.0,  # fix: missing comma here was a SyntaxError
    logger=logger,
    accumulate_grad_batches=1,
    # Fix: the callbacks above were created but never registered with the
    # Trainer, so early stopping and LR logging were silently inactive.
    callbacks=[early_stop_callback, lr_logger],
)
# Build the TFT from the dataset so input sizes, embeddings and variable
# selection networks are inferred from `training` automatically.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LEARNING_RATE,
    lstm_layers=2,
    hidden_size=64,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=64,
    output_size=1,  # single output matches a point-forecast loss like SMAPE
    loss=SMAPE(),
    reduce_on_plateau_patience=4,
    log_interval=-1,  # disable example/interpretation logging during training
)
# Fix: removed `tft.to(DEVICE)` — with PyTorch Lightning the Trainer owns
# device placement, and moving the module manually before fit() is at best a
# no-op on CPU and can cause device mismatches on accelerators.
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
Has anyone run into the same issue, or does anyone have an idea of how I can solve it?