Training fails due to Python-based feature extractor

I am trying to fine-tune the AST model on my dataset, and everything works until I get to the training step.

This is the error that is being returned:

```
***** Running training *****
  Num examples = 2
  Num Epochs = 4
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 4
  Number of trainable parameters = 86187264
KeyError                                  Traceback (most recent call last)
<ipython-input-26-97a81c88cb1e> in <module>
----> 1 train_results = trainer.train()
      3 # save tokenizer with the model
      4 trainer.save_model()
      5 trainer.log_metrics('train', train_results.metrics)

6 frames
/usr/local/lib/python3.8/dist-packages/transformers/ in __getitem__(self, item)
     84             return[item]
     85         else:
---> 86             raise KeyError("Indexing with integers is not available when using Python based feature extractors")
     88     def __getattr__(self, item: str):

KeyError: 'Indexing with integers is not available when using Python based feature extractors'
```
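
The `Num examples = 2` line may be a clue. `feature_extractor(...)` returns a `BatchFeature`, which is dict-like: `len()` counts its keys (here `input_values` and `label`), and indexing only accepts string keys. `Trainer` wraps `train_dataset` in a PyTorch `DataLoader`, which fetches examples by integer index, so passing the raw feature-extractor output would hit exactly this `KeyError`. A minimal reproduction, using dummy tensors in place of real audio features (the 1024 frames x 128 mel bins shape is assumed from AST's default settings):

```python
import torch
from transformers import BatchFeature

# Dummy stand-in for feature_extractor(...) output plus labels:
# 3 clips of 1024 frames x 128 mel bins, one integer label per clip.
batch = BatchFeature({"input_values": torch.zeros(3, 1024, 128), "label": [0, 1, 0]})

# BatchFeature behaves like a dict, so len() counts keys, not clips.
print(len(batch))  # 2

# String keys return the batched tensor...
print(batch["input_values"].shape)  # torch.Size([3, 1024, 128])

# ...but an integer index, which is how a DataLoader fetches examples,
# raises the same KeyError as in the traceback.
try:
    batch[0]
except KeyError as err:
    print(err)
```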

And this is my setup:

```python
import numpy as np
import torch
from transformers import ASTFeatureExtractor, Trainer, TrainingArguments

model_path = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor()

def preprocess_audio(features, labels):
    inputs = feature_extractor(features, return_tensors = 'pt', sampling_rate=16000)
    inputs['label'] = labels
    return inputs

# tensors are created on the CPU by default; if a GPU is available, use it for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# train and test are pandas DataFrames with a 'features' and a 'labels' column
vision_train = preprocess_audio(train['features'].tolist(), train['labels'].tolist())
vision_test = preprocess_audio(test['features'].tolist(), test['labels'].tolist())

def collate(batch):
    return {
        'input_values': torch.stack([x['input_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

def compute_metric(p):
    return metric.compute(
        predictions = np.argmax(p.predictions, axis = 1),
        references = p.label_ids
    )

# training arguments passed to the Trainer
args = TrainingArguments(
    output_dir = '/content/ast',
    save_steps = 100,
    learning_rate = 3e-5,
    save_total_limit = 2,
    remove_unused_columns = False,
    push_to_hub = False,
    load_best_model_at_end = True
)

trainer = Trainer(
    model = model,
    args = args,
    data_collator = collate,
    train_dataset = vision_train,
    eval_dataset = vision_test,
    tokenizer = feature_extractor
)
```
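
A minimal sketch of one possible way around this, assuming `label` holds integer class ids: wrap the feature-extractor output in a small PyTorch `Dataset` (the `AudioDataset` name is hypothetical) so that integer indexing returns one clip at a time, in the per-example dict shape that `collate` expects. The demo uses dummy tensors and a copy of the `collate` function above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class AudioDataset(Dataset):
    """Hypothetical wrapper that exposes one example per integer index.

    `encodings` is assumed to be dict-like (e.g. the BatchFeature returned by
    preprocess_audio) with a batched 'input_values' tensor and a 'label'
    list of integer class ids.
    """

    def __init__(self, encodings):
        self.input_values = encodings["input_values"]
        self.labels = encodings["label"]

    def __len__(self):
        # Number of clips, not number of keys.
        return len(self.labels)

    def __getitem__(self, idx):
        # Per-example dict in the shape collate() expects.
        return {"input_values": self.input_values[idx], "label": self.labels[idx]}


def collate(batch):
    # Copied from the post: stack per-example dicts into model inputs.
    return {
        "input_values": torch.stack([x["input_values"] for x in batch]),
        "labels": torch.tensor([x["label"] for x in batch]),
    }


# Dummy stand-in for preprocess_audio() output: 3 clips, labels 0/1/0.
encodings = {"input_values": torch.zeros(3, 1024, 128), "label": [0, 1, 0]}
dataset = AudioDataset(encodings)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate)
first = next(iter(loader))
print(len(dataset))  # 3
print(first["input_values"].shape)  # torch.Size([2, 1024, 128])
print(first["labels"])  # tensor([0, 1])
```

With a wrapper like this, `AudioDataset(vision_train)` and `AudioDataset(vision_test)` would be passed to `Trainer` as `train_dataset` and `eval_dataset` instead of the raw `BatchFeature` objects, and `Num examples` should then report the number of clips rather than the number of keys.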

Happy to provide more context, but I'm not really sure what the error means here. :sob: Any help is highly appreciated, @stevhliu