Bypassing "CUDA error: unspecified launch failure" error from trainer checkpoint loading

I run into errors when I use the Hugging Face Trainer to train a model and then load it back from a checkpoint. I have not been able to isolate the exact cause, so I cannot provide a minimal way to reproduce it. Even for the same model, only some configurations trigger the error. For example, I have this model:

from collections import OrderedDict
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from sklearn.metrics import balanced_accuracy_score

from tabrfm.data import DataProcessor

from .utils import compute_loss


class PlainMlpClassifierConfig(PretrainedConfig):
    """Hyper-parameters for ``PlainMlpClassifier``.

    Args:
        d_emb: Size of each embedding vector (used as ``embedding_dim`` of the
            model's ``nn.Embedding``).
        d_hidden: Width of the hidden linear layers of the classifier.
        n_hidden: Number of ``Linear`` + ``ReLU`` hidden stages. The default
            ``-1`` makes the classifier a single linear layer that maps the
            flattened embeddings straight to the outputs; any value greater
            than ``-1`` adds an input projection to ``d_hidden`` and a final
            output layer around the hidden stages.
        loss_name: Name forwarded to ``compute_loss`` when the model computes
            a loss.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        d_emb: int = 256,
        d_hidden: int = 8096,
        n_hidden: int = -1,
        loss_name: str = "cross_entropy",
        **kwargs,
    ) -> None:
        # Model-specific fields are stored first; the base class then handles
        # the remaining keyword arguments.
        self.d_emb, self.d_hidden = d_emb, d_hidden
        self.n_hidden, self.loss_name = n_hidden, loss_name
        super().__init__(**kwargs)


class PlainMlpClassifier(PreTrainedModel):
    """Embedding + MLP classifier over per-column feature indices.

    The constructor only initialises the ``PreTrainedModel`` base; no layers
    exist yet at that point. They are created in two explicit steps:
    ``adapt`` reads the dataset-dependent sizes from a ``DataProcessor`` and
    ``build`` creates ``embeddings`` and ``classifier`` from those sizes.
    Both methods return ``self`` so the calls can be chained.
    """

    config_class = PlainMlpClassifierConfig

    def __init__(self, config: PlainMlpClassifierConfig) -> None:
        super().__init__(config)

    def adapt(self, processor: DataProcessor):
        """Read the dataset-dependent dimensions from ``processor``.

        Sets ``d_out``, ``d_seq``, ``n_cols``, ``feature_shapes`` and
        ``target_mask`` on the model. ``feature_shapes`` and ``target_mask``
        are kept as plain attributes (not registered buffers), which is why
        ``forward`` moves them to the input device itself.

        Returns:
            ``self``.
        """
        processor.load()
        shapes = processor.get("feature_shapes", return_tensor=True)

        # Output size: total of the shape entries selected by the target column.
        n_out = int(shapes[processor.target_col].sum())
        self.d_out = n_out
        # Embedding table size: everything that does not belong to the target.
        self.d_seq = int(sum(shapes) - n_out)
        # One column is excluded from the inputs (the target).
        self.n_cols = int(len(processor.col_names) - 1)
        self.feature_shapes = shapes
        self.target_mask = processor.get("target_mask", return_tensor=True)
        return self

    def build(self):
        """Create ``embeddings`` and ``classifier``; call after ``adapt``.

        With ``config.n_hidden == -1`` the classifier is a single ``input``
        linear layer producing ``d_out`` values. Otherwise it is ``input``
        (to ``d_hidden``), ``n_hidden`` pairs of ``hidden_{i}_linear`` /
        ``hidden_{i}_relu``, and a final ``output`` layer.

        Returns:
            ``self``.
        """
        cfg = self.config

        self.embeddings = nn.Embedding(
            num_embeddings=self.d_seq,
            embedding_dim=cfg.d_emb,
        )

        # Embeddings of all input columns are flattened into one vector.
        flat_width = cfg.d_emb * self.n_cols
        has_output_layer = cfg.n_hidden > -1

        stages = OrderedDict()
        stages["input"] = nn.Linear(
            flat_width,
            cfg.d_hidden if has_output_layer else self.d_out,
        )
        for i in range(cfg.n_hidden):
            stages[f"hidden_{i}_linear"] = nn.Linear(cfg.d_hidden, cfg.d_hidden)
            stages[f"hidden_{i}_relu"] = nn.ReLU()
        if has_output_layer:
            stages["output"] = nn.Linear(cfg.d_hidden, self.d_out)

        self.classifier = nn.Sequential(stages)
        return self

    def forward(
        self,
        feature_idxs: torch.LongTensor,
        feature_vals: torch.FloatTensor,
        labels=None,
    ):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            feature_idxs: Per-column indices; columns are selected along
                dimension 1 (presumably shaped ``(batch, n_columns)`` --
                confirm against the data pipeline). The target column is
                read from this tensor as well.
            feature_vals: Per-column values that scale the matching
                embedding vectors; indexed the same way as ``feature_idxs``.
            labels: Only checked against ``None`` to decide whether a loss is
                computed; the loss target itself is taken from
                ``feature_idxs`` at the target column.

        Returns:
            A dict with ``"logits"`` and, if ``labels`` is not ``None``,
            ``"loss"``.
        """
        device = feature_idxs.device
        shapes = self.feature_shapes.to(device)
        is_target = self.target_mask.to(device)
        is_input = ~is_target

        # Start position of every input column inside the shared embedding
        # table: 0 for the first one, then the running total of the sizes of
        # the preceding input columns.
        n_inputs = len(is_target) - is_target.sum()
        offset = torch.zeros(n_inputs, dtype=torch.long).to(feature_idxs)
        offset[1:] = shapes[is_input].cumsum(0)[:-1]

        x_idx = feature_idxs[:, is_input] + offset
        x_val = feature_vals[:, is_input]

        # Scale each embedding by its feature value, then flatten all columns
        # into a single vector per row for the MLP.
        embedded = self.embeddings(x_idx) * x_val.unsqueeze(-1)
        logits = self.classifier(embedded.flatten(1))

        out = {"logits": logits}
        if labels is None:
            return out

        out["loss"] = compute_loss(
            self.config.loss_name,
            out=logits,
            target=feature_idxs[:, is_target].squeeze(dim=1),
        )
        return out

    @staticmethod
    def compute_metrics(preds):
        """Return the balanced accuracy of arg-max predictions.

        ``preds`` must expose ``label_ids`` and ``predictions``; the predicted
        class is the arg-max of ``predictions`` along dimension 1.
        """
        y_true = preds.label_ids
        y_pred = preds.predictions.argmax(1)
        return {"balanced_accuracy": balanced_accuracy_score(y_true, y_pred)}

Here, processor is a custom dataset loader that I implemented.
When I load the model manually in a single-GPU or CPU-only setting, I do not see any errors. However, when the Trainer uses all the GPUs on the node, I sometimes get this error:

Traceback (most recent call last):
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/gpfs/u/scratch/DDTD/DDTDkngn/tab-reprog-fm/tabrfm/scripts/train_rfm.py", line 161, in <module>
    main()
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/gpfs/u/scratch/DDTD/DDTDkngn/tab-reprog-fm/tabrfm/scripts/train_rfm.py", line 149, in main
    trainer.train(resume_from_checkpoint=can_continue)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/transformers/trainer.py", line 1904, in train
    self._load_from_checkpoint(resume_from_checkpoint)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/transformers/trainer.py", line 2575, in _load_from_checkpoint
    load_result = model.load_state_dict(state_dict, False)
  File "/gpfs/u/home/DDTD/DDTDkngn/scratch/miniconda-ppc/envs/tabrfm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PlainMlpClassifier:
        While copying the parameter named "embeddings.weight", whose dimensions in the model are torch.Size([153, 256]) and whose dimensions in the checkpoint are torch.Size([153, 256]), an exception occ
urred : ('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing C
UDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n',).
        While copying the parameter named "classifier.0.weight", whose dimensions in the model are torch.Size([2, 5120]) and whose dimensions in the checkpoint are torch.Size([2, 5120]), an exception occ
urred : ('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing C
UDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n',).
        While copying the parameter named "classifier.0.bias", whose dimensions in the model are torch.Size([2]) and whose dimensions in the checkpoint are torch.Size([2]), an exception occurred : ('CUDA
 error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLO
CKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n',).

Setting CUDA_LAUNCH_BLOCKING=1 does not change anything: I get the same traceback, and the underlying problem is still this "CUDA error: unspecified launch failure".

However, I found that if I manually add some sanity checks to the source code of trainer.py, I can reliably bypass this error.

I experimented with checking each weight's device and whether it contained NaN or inf values (it never did), and was able to reduce my checks to this:

# Workaround: walk over every tensor of the checkpoint's state dict and test
# it for NaN / inf before the weights are copied into the model.
# According to the author the two conditions were never true, so the branch
# bodies are placeholders; only the act of evaluating the checks matters.
# NOTE(review): `if <tensor>.any():` converts the result to a Python bool.
# If `v` lives on (or is tied to) a CUDA device, that conversion presumably
# waits for pending GPU work to finish, which may be why this avoids the
# launch failure -- unconfirmed, verify where `state_dict` tensors reside.
for k,v in state_dict.items():
    # `k` (the parameter name) is not used by the checks themselves.
    if torch.isnan(v).any():
        ... # do something
    if torch.isinf(v).any():
        ... # do something

Placing this code at the appropriate places (where my model's weights are loaded) inside _load_from_checkpoint and _load_best_checkpoint gets rid of this error.

I wanted to share this finding, and I would also like to understand why it actually works.