Error fine-tuning wav2vec2-xls-r-300m on Kaggle TPU

Here is my current code:
```python
from transformers import Wav2Vec2ForCTC

# processor is a Wav2Vec2Processor built earlier in the notebook
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.1,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=run_config["hub_repo"],
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    evaluation_strategy="steps",
    num_train_epochs=30,
    gradient_checkpointing=True,
    torch_compile=True,
    bf16=True,
    save_steps=400,
    eval_steps=400,
    logging_steps=400,
    learning_rate=3e-4,
    warmup_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    dataloader_num_workers=2,
    tpu_num_cores=8,
)
trainer.train()
```
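
For completeness, `trainer` is constructed along these lines (a sketch — the collator, dataset, and metric names below are illustrative stand-ins for objects defined earlier in the notebook):

```python
from transformers import Trainer

# Sketch of the Trainer construction that trainer.train() above assumes.
# data_collator, train_ds, eval_ds, and compute_metrics stand in for
# objects defined earlier in the notebook.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,        # CTC padding collator built from processor
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,    # e.g. WER
)
```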

Error

```
KeyError                                  Traceback (most recent call last)
Cell In[67], line 1
----> 1 trainer.train()

File /usr/local/lib/python3.8/site-packages/transformers/trainer.py:1547, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1544 try:
   1545     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1546     hf_hub_utils.disable_progress_bars()
-> 1547     return inner_training_loop(
   1548         args=args,
   1549         resume_from_checkpoint=resume_from_checkpoint,
   1550         trial=trial,
   1551         ignore_keys_for_eval=ignore_keys_for_eval,
   1552     )
   1553 finally:
   1554     hf_hub_utils.enable_progress_bars()

File /usr/local/lib/python3.8/site-packages/transformers/trainer.py:1685, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1683     model = self.accelerator.prepare(self.model)
   1684 else:
-> 1685     model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
   1686 else:
   1687     # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
   1688     model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
   1689         self.model, self.optimizer, self.lr_scheduler
   1690     )

File /usr/local/lib/python3.8/site-packages/accelerate/accelerator.py:1293, in Accelerator.prepare(self, device_placement, *args)
   1291 new_named_params = self._get_named_parameters(*result)
   1292 # 3. building a map from the first to the second
-> 1293 mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
   1294 # 4. using that map to update the parameters of the optimizer
   1295 for obj in result:

File /usr/local/lib/python3.8/site-packages/accelerate/accelerator.py:1293, in <dictcomp>(.0)
   1291 new_named_params = self._get_named_parameters(*result)
   1292 # 3. building a map from the first to the second
-> 1293 mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
   1294 # 4. using that map to update the parameters of the optimizer
   1295 for obj in result:

KeyError: 'wav2vec2.feature_extractor.conv_layers.0.conv.weight'
```

Edit:

This behaviour occurs only when torch_compile=True is set.

The code runs without it, but training is then extremely slow (0.01 it/s) and memory usage grows over time to 150 GB+.
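
If it helps narrow things down, the KeyError is consistent with a general torch.compile behaviour (I have not traced it through accelerate itself): compilation wraps the model in an OptimizedModule, which prefixes every parameter name with `_orig_mod.`, so a name map built from the uncompiled model no longer matches. A minimal, self-contained illustration:

```python
import torch
import torch.nn as nn

# Generic demonstration of the suspected cause, not my training code:
# torch.compile wraps the module, so parameter names gain an "_orig_mod." prefix.
m = nn.Linear(4, 4)
print([name for name, _ in m.named_parameters()])
# ['weight', 'bias']

cm = torch.compile(m)
print([name for name, _ in cm.named_parameters()])
# ['_orig_mod.weight', '_orig_mod.bias']
```

That would explain why the `new_named_params['wav2vec2.feature_extractor.conv_layers.0.conv.weight']` lookup in accelerate's prepare() fails only when compilation is enabled.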