Multiple calls to datasets.load_from_disk() cause a memory leak!

My dataset is stored on HDFS and is too large to save on local disk.
Using load_from_disk to pull all the files down and then concatenating them would waste a lot of time, especially with a large number of workers in distributed training.

So I implemented an IterableDataset that loads one file from HDFS at a time:

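```python
import glob
import itertools
import os

import datasets
# Assumption: HadoopFileSystem is fsspec's pyarrow-backed HDFS filesystem
from fsspec.implementations.arrow import HadoopFileSystem
from torch.utils import data
from torch.utils.data import IterableDataset


class StreamFileDataset(IterableDataset):
  def __init__(self, data_dir, cycle_mode=False):
    self.data_dir = data_dir
    self.cycle_mode = cycle_mode
    self._is_init = False

  def _config_fs(self):
    if self.data_dir.startswith("hdfs://"):
      self.fs = HadoopFileSystem()
      # Strip the "hdfs:/" prefix, keeping the leading "/"
      self.data_dir = self.data_dir.replace("hdfs:/", "")
      # detail=False so ls() returns sortable path strings
      self.data_files = sorted(self.fs.ls(self.data_dir, detail=False))
    else:
      self.fs = None
      self.data_files = sorted(glob.glob(os.path.join(self.data_dir, "*")))

  def _config_multi_worker(self):
    # Inside a DataLoader worker, keep every num_worker-th file for this worker
    worker_info = data.get_worker_info()
    if worker_info is not None:
      worker_id = worker_info.id
      num_worker = worker_info.num_workers
      indexes = range(worker_id, len(self.data_files), num_worker)
      self.data_files = [self.data_files[i] for i in indexes]

    if self.cycle_mode:
      # Loop over this worker's files forever
      self.data_files = itertools.cycle(self.data_files)

  def _init(self):
    # Run lazily so the setup happens inside each worker process
    if not self._is_init:
      self._config_fs()
      self._config_multi_worker()
      self._is_init = True

  def __iter__(self):
    self._init()
    for data_file in self.data_files:
      data = datasets.load_from_disk(data_file, fs=self.fs)
      for d in data:
        yield d
      # Manually delete data to avoid memory leaks
      del data
```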
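The per-worker split in `_config_multi_worker` is a strided (round-robin) assignment of files to DataLoader workers. The sketch below isolates that logic without PyTorch so it can be checked on its own; `shard_files` and `iter_shard` are hypothetical helper names, not part of `datasets` or PyTorch.

```python
import itertools
from typing import Iterator, List, Sequence


def shard_files(files: Sequence[str], worker_id: int, num_workers: int) -> List[str]:
    """Round-robin split: worker k gets files k, k + n, k + 2n, ... (n = num_workers).

    Selects the same files as range(worker_id, len(files), num_workers).
    """
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return list(files[worker_id::num_workers])


def iter_shard(files: Sequence[str], cycle_mode: bool = False) -> Iterator[str]:
    """Yield this worker's files once, or repeat them forever if cycle_mode is set."""
    # itertools.cycle over an empty shard stops immediately instead of looping.
    return itertools.cycle(files) if cycle_mode else iter(files)
```

With 5 files and 2 workers, worker 0 reads part-0, part-2 and part-4 while worker 1 reads part-1 and part-3. If there are more workers than files, some workers get an empty shard and simply yield nothing.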

But now something bad happens: there is a memory leak!

In the memory graph above, each increase happens when load_from_disk reads the next file.

Then I ran a test that prints RAM usage before and after each load_from_disk call:

```python
import psutil

for data_file in self.data_files:
  print("before")
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
  data = datasets.load_from_disk(data_file, fs=self.fs)
  print("after")
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
```
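The loop above reads RSS right after each load, while `data` still references the freshly loaded dataset. A slightly stricter check drops the reference and forces garbage collection before taking the "after" reading, so growth that survives `del` plus `gc.collect()` is easier to tell apart from objects that are simply still alive. This is a minimal sketch under assumptions: `rss_mb` and `profile_loads` are hypothetical helper names, and `rss_mb` reads `/proc/self/statm`, so it only works on Linux (it should report the same resident-set figure as `psutil.Process().memory_info().rss`).

```python
import gc
import os
from typing import Any, Callable, Iterable, List, Tuple


def rss_mb() -> float:
    """Current resident set size in MB (Linux only: reads /proc/self/statm)."""
    with open("/proc/self/statm") as f:
        resident_pages = int(f.read().split()[1])  # 2nd field = resident pages
    return resident_pages * os.sysconf("SC_PAGE_SIZE") / (1024 * 1024)


def profile_loads(
    paths: Iterable[str],
    load: Callable[[str], Any],
    probe: Callable[[], float] = rss_mb,
) -> List[Tuple[str, float, float]]:
    """Load each path, drop the result, collect garbage, and record memory before/after."""
    history = []
    for path in paths:
        gc.collect()
        before = probe()
        ds = load(path)
        del ds  # drop our only reference before measuring
        gc.collect()
        after = probe()
        history.append((path, before, after))
    return history


# Hypothetical usage with the real loader (not executed here):
# import datasets
# for path, before, after in profile_loads(files, datasets.load_from_disk):
#     print(f"{path}: {before:.2f} MB -> {after:.2f} MB")
```

If "after" still climbs from file to file in this loop, memory is being retained beyond the Python objects the loop holds. RSS also counts resident pages of memory-mapped files, so it is worth keeping in mind when reading the numbers for memory-mapped Arrow data.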

Memory usage grows gradually with every file loaded!

I also ran the same experiment with the data saved locally, and the memory leak occurs there too.

Is this a bug, or is there any other solution?