InvalidArgumentError when training Segformer

I am getting an InvalidArgumentError as soon as I call model.fit and I can't work out what it is about. I'm fine-tuning TFSegformerForSemanticSegmentation (nvidia/mit-b0) on a binary landslide / non-landslide segmentation dataset; the full notebook code and the traceback are below. Any help would be appreciated.

# %%
!pip install -r requirements.txt -q

# %%
import json

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from datasets import load_dataset, load_metric
from huggingface_hub import cached_download, hf_hub_url

from transformers import (
    DefaultDataCollator,
    SegformerFeatureExtractor,
    TFSegformerForSemanticSegmentation,
    create_optimizer,
)
from transformers.keras_callbacks import KerasMetricCallback

# %%
epochs = 2
lr = 0.00006
batch_size = 2

# %%
from datasets import Dataset, DatasetDict, Image, load_metric
import os


# your images can of course have a different extension;
# semantic segmentation maps are typically stored in the png format.
# Sort both listings so that image/mask pairs stay aligned across the split below.
image_paths_train = sorted([f'training/img/{i}' for i in os.listdir('training/img')])
label_paths_train = sorted([f'training/mask/{i}' for i in os.listdir('training/mask')])

# split the data into train (80%) and validation (20%)
train_size = int(len(image_paths_train) * 0.8)

image_paths_validation = image_paths_train[train_size:]
label_paths_validation = label_paths_train[train_size:]

image_paths_train = image_paths_train[:train_size]
label_paths_train = label_paths_train[:train_size]
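
# Sanity check (sketch): the images and masks come from two separate os.listdir
# calls, so make sure every image has a mask with the same base name. This assumes
# pairing by file name, which is also what the sorting in create_dataset relies on.
image_names = {os.path.splitext(os.path.basename(p))[0]
               for p in image_paths_train + image_paths_validation}
mask_names = {os.path.splitext(os.path.basename(p))[0]
              for p in label_paths_train + label_paths_validation}
assert image_names == mask_names, f"unpaired files: {image_names ^ mask_names}"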


def create_dataset(image_paths, label_paths):
    dataset = Dataset.from_dict({"pixel_values": sorted(image_paths),
                                "label": sorted(label_paths)})
    
    dataset = dataset.cast_column("pixel_values", Image())
    dataset = dataset.cast_column("label", Image())
    
    return dataset


train_dataset = create_dataset(image_paths_train, label_paths_train)
validation_dataset = create_dataset(image_paths_validation, label_paths_validation)

# %%
print(train_dataset)

# %%
feature_extractor = SegformerFeatureExtractor()


def transforms(image):
    # PIL image -> (H, W, C) array -> channels-first (C, H, W), which is the
    # layout the Segformer model expects for pixel_values
    image = tf.keras.utils.img_to_array(image)
    image = image.transpose((2, 0, 1))
    return image


def preprocess(example_batch):
    images = [transforms(x.convert("RGB")) for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = feature_extractor(images, labels)
    return inputs

# %%
train_dataset.set_transform(preprocess)
validation_dataset.set_transform(preprocess)

# %%
data_collator = DefaultDataCollator(return_tensors="tf")

train_set = train_dataset.to_tf_dataset(
    columns=["pixel_values", "label"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)

val_set = validation_dataset.to_tf_dataset(
    columns=["pixel_values", "label"],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator,
)

# %%
train_set.element_spec


# %%
# Investigate a single batch.
batch = next(iter(train_set))
batch["pixel_values"].shape, batch["labels"].shape

# %%
# label mapping for this binary segmentation task
id2label = {0: 'non-landslide', 1: 'landslide'}
with open('id2label.json', 'w') as fp:
    json.dump(id2label, fp)

label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
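
# %%
# Sanity check (sketch): confirm the masks only contain the class ids in id2label.
# This assumes the masks are single-channel images whose pixel values are the raw
# class ids (0 and 1), not e.g. 0/255.
from PIL import Image as PILImage

sample_mask = np.array(PILImage.open(label_paths_train[0]))
print(np.unique(sample_mask))  # expecting a subset of {0, 1}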

# %%
metric = load_metric("mean_iou")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # logits are of shape (batch_size, num_labels, height, width), so
    # we first transpose them to (batch_size, height, width, num_labels)
    logits = tf.transpose(logits, perm=[0, 2, 3, 1])
    # scale the logits to the size of the label
    logits_resized = tf.image.resize(
        logits,
        size=tf.shape(labels)[1:],
        method="bilinear",
    )
    # compute the predicted labels and evaluate the metric
    pred_labels = tf.argmax(logits_resized, axis=-1)
    metrics = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=num_labels,
        ignore_index=-1,
        reduce_labels=feature_extractor.reduce_labels,
    )
    return {"val_" + k: v for k, v in metrics.items()}


metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics,
    eval_dataset=val_set,
    batch_size=batch_size,
    label_cols=["labels"],
)

# %%

model_checkpoint = "nvidia/mit-b0"
model = TFSegformerForSemanticSegmentation.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # Will ensure the segmentation specific components are reinitialized.
)


optimizer = tf.keras.optimizers.Adam(learning_rate=lr)  # use the lr defined above
model.compile(optimizer=optimizer)  # no loss needed: the model computes its loss from the labels
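
# Alternative (sketch): create_optimizer (imported above) builds an AdamW optimizer
# with a linear warmup/decay schedule, as in the HF examples. The step count below
# is just an estimate from my dataset size; num_warmup_steps is a guess on my part.
# num_train_steps = (len(train_dataset) // batch_size) * epochs
# optimizer, lr_schedule = create_optimizer(
#     init_lr=lr,
#     num_train_steps=num_train_steps,
#     num_warmup_steps=0,
# )
# model.compile(optimizer=optimizer)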

# %%
model.summary()

# %%
callbacks = [metric_callback] 


# %%
model.fit(
    train_set,
    validation_data=val_set,
    callbacks=callbacks,
    epochs=epochs,
)

The error message:


2024-03-21 15:12:14.288080: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape intf_segformer_for_semantic_segmentation/decode_head/dropout_24/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2024-03-21 15:12:16.624711: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:458] Loaded runtime CuDNN library: 8.7.0 but source was compiled with: 8.9.6.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
2024-03-21 15:12:16.625573: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at conv_ops_impl.h:1201 : INVALID_ARGUMENT: No DNN in stream executor.
2024-03-21 15:12:16.625616: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: No DNN in stream executor.
	 [[{{node tf_segformer_for_semantic_segmentation/segformer/encoder/patch_embeddings.0/proj/Conv2D}}]]
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[17], line 1
----> 1 model.fit(
      2     train_set,
      3     validation_data=val_set,
      4     callbacks=callbacks,
      5     epochs=epochs,
      6 )

File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py:1170, in TFPreTrainedModel.fit(self, *args, **kwargs)
   1167 @functools.wraps(keras.Model.fit)
   1168 def fit(self, *args, **kwargs):
   1169     args, kwargs = convert_batch_encoding(*args, **kwargs)
-> 1170     return super().fit(*args, **kwargs)

File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node tf_segformer_for_semantic_segmentation/segformer/encoder/patch_embeddings.0/proj/Conv2D defined at (most recent call last):
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/runpy.py", line 86, in _run_code

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 531, in process_one

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 359, in execute_request

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 775, in execute_request

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 446, in do_execute

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/tmp/ipykernel_25233/3682213347.py", line 1, in <module>

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1170, in fit

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 1804, in fit

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 1398, in train_function

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 1381, in step_function

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 1370, in run_step

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1610, in train_step

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1613, in train_step

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 588, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1047, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__

  File "/tmp/__autograph_generated_filehsh0fy6b.py", line 34, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 970, in run_call_with_unpacked_inputs

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 997, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1047, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__

  File "/tmp/__autograph_generated_filehsh0fy6b.py", line 34, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 970, in run_call_with_unpacked_inputs

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 605, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1047, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__

  File "/tmp/__autograph_generated_filehsh0fy6b.py", line 34, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 519, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 522, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1047, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__

  File "/tmp/__autograph_generated_filehsh0fy6b.py", line 34, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 100, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 553, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 558, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1047, in __call__

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__

  File "/tmp/__autograph_generated_filehsh0fy6b.py", line 34, in error_handler

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/layers/convolutional/base_conv.py", line 284, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/layers/convolutional/base_conv.py", line 289, in call

  File "/home/ec2-user/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/tf_keras/src/layers/convolutional/base_conv.py", line 268, in convolution_op

No DNN in stream executor.
	 [[{{node tf_segformer_for_semantic_segmentation/segformer/encoder/patch_embeddings.0/proj/Conv2D}}]] [Op:__inference_train_function_32710]
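
The log also mentions a cuDNN mismatch ("Loaded runtime CuDNN library: 8.7.0 but source was compiled with: 8.9.6"), so in case that is relevant, this is roughly how I am comparing the CUDA/cuDNN versions TensorFlow was built against with what my environment provides (just a sketch; the build_info keys can differ between TF builds):

import tensorflow as tf

build = tf.sysconfig.get_build_info()  # versions TF was compiled against
print("built with CUDA:", build.get("cuda_version"))
print("built with cuDNN:", build.get("cudnn_version"))
print("GPUs visible:", tf.config.list_physical_devices("GPU"))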