Gradients in Data Collator cause Memory Leak

Hi there, I am training a BERT and GNN model end-to-end for a graph representation learning task on knowledge graphs. Producing the graph embeddings with gradients in the DataCollator and stacking them, without manually deleting the original tensors, leads to a memory leak.

Here is the relevant snippet from my custom _convert_features_into_batches:

    source_embeddings, target_embeddings = self.get_embeddings_cb(
        self.data, source_ids, target_ids
    )
    graph_embeddings = torch.stack([source_embeddings, target_embeddings], dim=1)
    del source_embeddings, target_embeddings
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "graph_embeddings": graph_embeddings,
        "semantic_positional_encoding": semantic_positional_encoding,
    }

My system:

  • transformers version: 4.42.3
  • Platform: Windows-10-10.0.19045-SP0
  • Python version: 3.12.4
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: False
  • Using GPU in script?: True
  • GPU type: NVIDIA GeForce RTX 2080 SUPER

Disclaimer: I know it makes more sense to produce the embeddings inside the model's forward method, which I will do eventually. I just wanted to let you know that I ran into this issue. I think gradients in memory that are not released are a general concern, and I understand if this is not an issue that needs to be fixed by Hugging Face.

Greetings, Ahmad


Hello.
Getting a PyTorch tensor to actually disappear from memory is tricky.
I think it would be beneficial to share the know-how on how to work around it.

@not-lain Do you know which of the currently active HF staff maintain the base classes of the HF libraries that directly touch the PyTorch side?
For example, even if there were a problem with from_pretrained() or .to(), there are too many potential people to consult.
I do know some of the staff who work on the demos, since they stand out…

Well, we could use a GitHub or Discord account, but it would be more useful if we could report problems that occur in HF within HF itself.

Hi @AhmadPython,
the problem is with this line.

The way PyTorch works, even if you delete the Python variables it will not free the CUDA cache (try deleting a model: it will still be allocated in memory). To fix this, consider adding these lines after you delete your variables:

    import gc
    import torch

    gc.collect()              # drop any lingering Python references first
    torch.cuda.empty_cache()  # then release the cached blocks back to the driver
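
For reference, a minimal self-contained sketch (a stand-in tensor rather than the actual collator output) showing that the memory only goes back to the driver after empty_cache():

    import gc
    import torch

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")  # stand-in for the stacked embeddings
        print(torch.cuda.memory_allocated())        # > 0 while the tensor is referenced

        del x
        gc.collect()
        print(torch.cuda.memory_allocated())        # drops once the reference is gone
        print(torch.cuda.memory_reserved())         # the caching allocator still holds the blocks

        torch.cuda.empty_cache()
        print(torch.cuda.memory_reserved())         # now released back to the driver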

@John6666 and @not-lain, thanks for your help.
I think I found the original culprit of this problem, though I think your solution would have fixed the memory leak as well.

The problem was that my PyTorch Geometric GNN was sitting on the CPU while the Hugging Face model was on the GPU. I guess what happened was that as soon as the tensor moved to the other device, it became a leaf tensor and the gradient flow was lost. That left the original tensor stuck in memory, as @not-lain already suggested.

My solution was to move both tensors to the same device.
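
Roughly, in sketch form (stand-in modules instead of my actual GNN and BERT):

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gnn_stub = nn.Linear(8, 16).to(device)    # stand-in for the PyTorch Geometric GNN
    bert_stub = nn.Linear(16, 16).to(device)  # stand-in for the Hugging Face model

    node_features = torch.randn(4, 8, device=device)
    source_embeddings = gnn_stub(node_features)
    target_embeddings = gnn_stub(node_features)
    graph_embeddings = torch.stack([source_embeddings, target_embeddings], dim=1)

    # The embeddings are created on the device the downstream model uses,
    # so no cross-device copy sits in the middle of the autograd graph.
    assert graph_embeddings.device == next(bert_stub.parameters()).device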

That being said, I am glad this memory leak occurred, because otherwise I would not have noticed that I have to be really careful with the gradient flow. For example, at the point where I merge the LLM embeddings with the GNN embeddings, I use torch.scatter, which does not track gradients back to the source tensor by default.
Instead I will have to use slicing; see the sketch after the snippet below.

    mask = (
        semantic_positional_encoding[:, [-4, -2], 0]
        .unsqueeze(-1)
        .repeat((1, 1, inputs_embeds.shape[-1]))
    )
    # Replace the input embeds at the placeholder positions with the KGEs.
    inputs_embeds = inputs_embeds.scatter(1, mask, graph_embeddings)
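
As a small self-contained sketch (dummy shapes and positions, not my real batch), the indexing-based version where gradients do reach graph_embeddings:

    import torch

    batch, seq_len, hidden = 2, 8, 16
    inputs_embeds = torch.randn(batch, seq_len, hidden, requires_grad=True)
    graph_embeddings = torch.randn(batch, 2, hidden, requires_grad=True)
    placeholder_positions = torch.tensor([[3, 5], [2, 6]])  # (batch, 2) dummy positions

    # Advanced indexing writes the KGEs into the placeholder slots while
    # autograd keeps tracking gradients back to graph_embeddings.
    batch_idx = torch.arange(batch).unsqueeze(-1)  # (batch, 1), broadcasts over positions
    merged = inputs_embeds.clone()
    merged[batch_idx, placeholder_positions] = graph_embeddings

    merged.sum().backward()
    print(graph_embeddings.grad.shape)  # torch.Size([2, 2, 16]): gradients reached the KGEs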

Anyways, thanks for the help!

