InvalidArgumentError with vit-base-patch16-224 model?

Hi. I am using a pretrained model based on Google's vit-base-patch16-224-in21k for binary image classification (human vs. non-human).
I am using the Keras API with TensorFlow 2.6.0.

Here are the relevant parts of my code.

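import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from transformers import TFViTModel

# Download the pretrained base model
base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')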

# Flipping and rotating images
data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)
# Freeze base model
base_model.trainable = False
# Create new model
inputs = keras.Input(shape = (3, 224, 224))
x = data_augmentation(inputs)   # apply data augmentation

x = base_model(x, training=False)[0]
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)


# model
model_vit = tf.keras.Model(inputs, outputs)
model_vit.compile(loss='binary_crossentropy',optimizer='adam', metrics=['accuracy'])

model_vit.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_6 (InputLayer)         [(None, 3, 224, 224)]     0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 3, 224, 224)       0         
_________________________________________________________________
tf_vi_t_model (TFViTModel)   TFBaseModelOutputWithPool 86389248  
_________________________________________________________________
dense_2 (Dense)              (None, 197, 1)            769       
=================================================================
Total params: 86,390,017
Trainable params: 769
Non-trainable params: 86,389,248

By the way, almost all of the parameters are non-trainable, presumably because the base model is frozen.
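
The numbers in the summary can be checked by hand. This is only a rough sketch: the image and patch sizes are read off the model name, the hidden size of 768 is inferred from the 769 parameters of the Dense layer, and the random array is a stand-in for the model's first output, not real data.

```python
import numpy as np

# Values read off the model name "vit-base-patch16-224" and the summary
image_size = 224
patch_size = 16
hidden_size = 768  # inferred: Dense(1) shows 769 params = 768 weights + 1 bias

patches_per_side = image_size // patch_size  # 14
num_patches = patches_per_side ** 2          # 196
num_tokens = num_patches + 1                 # plus one class token

# Dense(1): one weight per input feature plus one bias
dense_params = hidden_size * 1 + 1

print(num_tokens)    # 197
print(dense_params)  # 769

# Stand-in for the model's first output: (batch, tokens, hidden)
hidden_states = np.random.rand(8, num_tokens, hidden_size).astype(np.float32)

# Keeping only the first (class) token gives one vector per image
cls_vector = hidden_states[:, 0, :]
print(cls_vector.shape)  # (8, 768)
```

This would explain the (None, 197, 1) output shape: the Dense layer is applied to each of the 197 tokens separately. For one prediction per image, it may be worth feeding a single vector per image into the Dense layer, for example the first token as above or the model's pooled output.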

When I run the training I get this error:

# Train the Vit model
vit_trained_model = model_vit.fit( X_train_images, y_train_labels, validation_data=(X_val_images, y_val_labels), batch_size = 8, verbose=2, epochs=50)

scores = model_vit.evaluate(test_images, test_labels_binary, verbose=0)
print("ViT Model Accuracy on Test Set: %.2f%%" % (scores[1]*100))


---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_28616\3601201585.py in <cell line: 2>()
      1 # Train the Vit model
----> 2 vit_trained_model = model_vit.fit( X_train_images, y_train_labels, validation_data=(X_val_images, y_val_labels), batch_size = 8, verbose=2, epochs=50)
      3 
      4 scores = model_vit.evaluate(test_images, test_labels_binary, verbose=0)
      5 print("Xception Accuracy on Test Set: %.2f%%" % (scores[1]*100))
....

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\constant_op.py in convert_to_eager_tensor(value, ctx, dtype)
    104       dtype = dtypes.as_dtype(dtype).as_datatype_enum
    105   ctx.ensure_initialized()
--> 106   return ops.EagerTensor(value, ctx.device_name, dtype)
    107 
    108 

InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
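
One way to sanity-check whether memory could be the issue is to estimate the size of the training array on its own. This is a rough sketch: the shape is the one reported for X_train_images in this post, but the dtype is not stated anywhere, so float32 and float64 are both assumptions. `array_size_gib` is a hypothetical helper written for this example.

```python
import numpy as np

# Shape reported for X_train_images in this post
train_shape = (3932, 224, 224, 3)

def array_size_gib(shape, dtype):
    """Size in GiB of a dense array with the given shape and dtype."""
    n_elements = int(np.prod(shape, dtype=np.int64))
    return n_elements * np.dtype(dtype).itemsize / 1024**3

# The actual dtype is unknown, so both common cases are shown
size_float32 = array_size_gib(train_shape, np.float32)
size_float64 = array_size_gib(train_shape, np.float64)

print(f"float32: {size_float32:.2f} GiB")
print(f"float64: {size_float64:.2f} GiB")
```

Under these assumptions the array alone takes roughly 2.2 GiB as float32 and twice that as float64, before counting the model weights and activations.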

This error most likely means that the GPU ran out of memory while the training data was being copied to it. So I tried another approach: a generator that feeds the data to the model one batch at a time:

from tensorflow.keras.utils import Sequence
import numpy as np   

class DataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return batch_x, batch_y

train_gen = DataGenerator(X_train_images, y_train_labels, 16)
test_gen = DataGenerator(X_val_images, y_val_labels, 16)

history = model_vit.fit(train_gen,
                    epochs=6,
                    validation_data=test_gen)

This gives an error that seems to be related to the input shapes:

Epoch 1/6
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_13140\3904230856.py in <cell line: 21>()
     19 
     20 
---> 21 history = model_vit.fit(train_gen,
     22                     epochs=6,
     23                     validation_data=test_gen)

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1182                 _r=1):
   1183               callbacks.on_train_batch_begin(step)
-> 1184               tmp_logs = self.train_function(iterator)
   1185               if data_handler.should_sync:
   1186                 context.async_wait()

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    948         # Lifting succeeded, so variables are initialized and we can run the
    949         # stateless function.
--> 950         return self._stateless_fn(*args, **kwds)
    951     else:
    952       _, _, _, filtered_flat_args = \

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   3037       (graph_function,
   3038        filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 3039     return graph_function._call_flat(
   3040         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3041 

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1961         and executing_eagerly):
   1962       # No tape is watching; skip to running the function.
-> 1963       return self._build_call_outputs(self._inference_function.call(
   1964           ctx, args, cancellation_manager=cancellation_manager))
   1965     forward_backward = self._select_forward_and_backward_functions(

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    589       with _InterpolateFunctionError(self):
    590         if cancellation_manager is None:
--> 591           outputs = execute.execute(
    592               str(self.signature.name),
    593               num_outputs=self._num_outputs,

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError:  input depth must be evenly divisible by filter depth: 224 vs 3
	 [[node model/tf_vi_t_model/vit/embeddings/patch_embeddings/projection/Conv2D (defined at ProgramData\Anaconda3\lib\site-packages\transformers\models\vit\modeling_tf_vit.py:199) ]] [Op:__inference_train_function_30507]

Errors may have originated from an input operation.
Input Source operations connected to node model/tf_vi_t_model/vit/embeddings/patch_embeddings/projection/Conv2D:
 model/tf_vi_t_model/vit/embeddings/patch_embeddings/transpose (defined at ProgramData\Anaconda3\lib\site-packages\transformers\models\vit\modeling_tf_vit.py:197)

Function call stack:
train_function

Can anyone explain what "input depth must be evenly divisible by filter depth: 224 vs 3" means here and how to fix it?
The shapes of my training and validation data are as follows:

Train: X_train_images=(3932, 224, 224, 3), y_train_labels=(3932, 1)
Validation: X_val_images=(800, 224, 224, 3), y_val_labels=(800, 1)
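
For reference, the mismatch between these shapes and the declared Input(shape=(3, 224, 224)) can be reproduced with NumPy alone. This is a minimal sketch, not the library's actual code: it assumes the transpose named in the traceback (patch_embeddings/transpose) moves axis 1 to the last position before the Conv2D, and it uses a small zero-filled stand-in batch instead of the real images.

```python
import numpy as np

# Stand-in batch with the same layout as X_train_images: (N, H, W, C)
batch_nhwc = np.zeros((8, 224, 224, 3), dtype=np.float32)

# The model reads its input as (N, C, H, W). Assuming the internal transpose
# moves axis 1 to the end, a channels-last batch ends up with 224 in the
# channel (depth) position, while the convolution filter expects depth 3.
misread = np.transpose(batch_nhwc, (0, 2, 3, 1))
print(misread.shape)  # (8, 224, 3, 224)

# Converting the data to channels-first beforehand matches the declared input
batch_nchw = np.transpose(batch_nhwc, (0, 3, 1, 2))
print(batch_nchw.shape)  # (8, 3, 224, 224)

# The same internal transpose then ends with 3 channels, as expected
restored = np.transpose(batch_nchw, (0, 2, 3, 1))
print(restored.shape)  # (8, 224, 224, 3)
```

If this reading is right, "224 vs 3" is the input depth (224) against the filter depth (3). Two possible fixes, both untested here: apply the same np.transpose(batch_x, (0, 3, 1, 2)) inside __getitem__ of the generator, or declare the input as (224, 224, 3) and add a keras.layers.Permute((3, 1, 2)) layer after the augmentation, which would also let RandomFlip and RandomRotation see channels-last images.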

This is my first time experimenting with a ViT transfer-learning model. Thank you very much! Any other advice on my model architecture is welcome too.

P.S. I used this article as a guide to install Hugging Face Transformers through Anaconda and to run a model on my image dataset: https://www.philschmid.de/image-classification-huggingface-transformers-keras