PyTorch Lightning - Memory Leak

I am training a TFT model from PyTorch Forecasting, but I am running into memory issues. Memory usage rises at every batch iteration until the end of the first epoch and then stays at that level. I have tried different batch sizes, model parameters, and smaller datasets, but nothing changed. I am training on CPU on Google Colab with 51 GB of RAM, and the session crashes before the second epoch (I am watching per-batch memory with a small callback, shown after the config below). Here is my config:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import SMAPE

# Let's create a Dataset
training = TimeSeriesDataSet(
    train,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=120,
    max_prediction_length=1,
    static_categoricals=[],
    time_varying_known_categoricals=time_category_columns,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=time_reals,
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# create a validation set (predict=True), which means predicting the last
# max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, train, predict=True, stop_randomization=True)
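
As a sanity check for the predict=True behaviour (my own check, not part of the original tutorial), the validation set should contain one sample per series, each covering the last max_prediction_length points:

print(f"validation samples: {len(validation)}")  # roughly the number of groups in train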

# create dataloaders for model
batch_size = 64  # set this between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
early_stop_callback = EarlyStopping(monitor="train_loss", min_delta=1e-2, patience=PATIENCE, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard


trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    enable_model_summary=True,
    gradient_clip_val=0.25,
    limit_train_batches=1.0,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
    accumulate_grad_batches=1,
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LEARNING_RATE,
    lstm_layers=2,
    hidden_size=64,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=64,
    output_size=1,
    loss=SMAPE(),
    reduce_on_plateau_patience=4,
    log_interval=-1
)

tft.to(DEVICE)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

Has anyone run into the same issue, or does anyone have an idea how I can solve it?