Unsupported value type BatchEncoding returned by IteratorSpec._serialize

Hi all!

I’m having a go at fine-tuning BERT for a regression problem (given a passage of text, predict its readability score) as part of a Kaggle competition.

To do so I’m doing the following:

1. Loading the BERT tokenizer and applying it to the dataset

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# df_raw is the training dataset; tokenize the text column as NumPy arrays
tokens = tokenizer(list(df_raw['excerpt']), padding='max_length', truncation=True, return_tensors="np")
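For reference, the tokenizer returns a BatchEncoding rather than a plain Python dict; with return_tensors="np" its values are NumPy arrays. A quick sanity check (the shapes below are just what I'd expect for bert-base-uncased, which pads to 512 with padding='max_length'):

# Inspect what the tokenizer actually returned
print(type(tokens).__name__)      # BatchEncoding
print(list(tokens.keys()))        # ['input_ids', 'token_type_ids', 'attention_mask']
print(tokens['input_ids'].shape)  # (num_examples, 512)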

2. Adding the target column and creating a TensorFlow dataset

import tensorflow as tf

tokens_w_labels = tokens.copy()
tokens_w_labels['target'] = df_raw['target']

tokens_dataset = tf.data.Dataset.from_tensor_slices(
    tokens_w_labels
)

3. Loading a sequence classification model and attempting to fine-tune it

from transformers import TFAutoModelForSequenceClassification

# num_labels=1 should produce regression (numerical) output
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=1)

# Compile with relevant metrics / loss
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=['mean_absolute_error', 'mean_squared_error'],
)

# Removed validation data for the time being
model.fit(
    tokens_dataset,
    batch_size=8,
    epochs=3
)

It’s at this step that I get an error: Unsupported value type BatchEncoding returned by IteratorSpec._serialize. I’ve tried a few different setups and I can’t figure out where the issue is.

The specific part of TensorFlow that the error comes from is here.

Any pointers on what’s going wrong here?

Eventually got this working. It turns out the error was in step 2, where I was combining the target column with the tokenized inputs before creating the dataset.

I also needed to turn the tokenized inputs into a plain dict.
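As far as I can tell, the underlying issue is that BatchEncoding is a UserDict subclass rather than an actual dict, so tf.data doesn't recognise it as a structure it can serialize; casting it with dict() is what fixes that:

# BatchEncoding is dict-like but not an actual dict, which is what
# trips up tf.data's structure serialization
isinstance(tokens, dict)        # False
isinstance(dict(tokens), dict)  # True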

# This version works: features as a plain dict, labels passed separately
train_dataset = tf.data.Dataset.from_tensor_slices((
    dict(tokens),
    df_raw['target'].values
))
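One related note: from_tensor_slices yields one example at a time, so the dataset itself needs to be batched; Keras doesn't accept a batch_size argument when the input is a tf.data.Dataset. A minimal sketch of the training call with that change:

# Shuffle and batch the dataset, then fit without batch_size
train_dataset = train_dataset.shuffle(len(df_raw)).batch(8)

model.fit(
    train_dataset,
    epochs=3
)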

Replace tokens_dataset with tokens_dataset['input_ids'].