AssertionError "No inf checks were recorded for this optimizer" when training a transformer with the Hugging Face Trainer
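
For context, tokenizer, model, and data are created earlier in the notebook, roughly like this (the checkpoint name and data files below are placeholders, not the exact ones I use):

import transformers
from datasets import load_dataset

# Assumed setup (not shown in the failing cell)
model_name = "allenai/scibert_scivocab_uncased"  # assumed SciBERT checkpoint
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)  # causal-LM head, matching mlm=False below

# Placeholder corpus; the real data files are not shown here
data = load_dataset("text", data_files={"train": "train.txt"})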

def preprocess_function(example):
    text = example["text"]
    return tokenizer(
        text,
        return_tensors="pt",
        truncation=True,  # Truncate longer sequences
        padding=True,  # Pad shorter sequences
        max_length=512
    )

# Map the preprocess_function to the dataset
tokenized_data = data.map(preprocess_function, batched=True)
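
With batched=True, preprocess_function receives a batch whose "text" field is a list of strings, which the tokenizer accepts directly. A quick sanity check on one tokenized row (the exact columns depend on the tokenizer):

sample = tokenized_data["train"][0]
print(sample.keys())             # e.g. dict_keys(['text', 'input_ids', 'attention_mask', ...])
print(len(sample["input_ids"]))  # at most 512 after truncation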

# Use the tokenizer's BOS token (<s>) as both the padding and EOS token
tokenizer.pad_token = tokenizer.eos_token = tokenizer.bos_token  # <s>

# Causal-LM data collator (mlm=False: labels are a copy of input_ids)
data_collator = transformers.DataCollatorForLanguageModeling(
    tokenizer, mlm=False, return_tensors="pt"
)
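
With mlm=False this collator copies input_ids into labels (padding positions become -100), i.e. a causal-LM setup. A quick check on two rows, keeping only the tokenizer columns since the raw "text" strings would trip up the collator:

features = [
    {k: tokenized_data["train"][i][k] for k in ("input_ids", "attention_mask")}
    for i in range(2)
]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)  # both padded to the same length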

lr = 2e-4
batch_size = 4
num_epochs = 10

# Define training arguments
training_args = transformers.TrainingArguments(
    output_dir="workspace",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    fp16=True,
    optim="paged_adamw_8bit",
)
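
Note: with per_device_train_batch_size=4 and gradient_accumulation_steps=4, the effective batch size is 4 × 4 = 16 per device, and the optimizer only steps once every 4 forward passes.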

# Make sure the tokenizer has a dedicated padding token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
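
If '[PAD]' is genuinely new to the vocabulary, the embedding matrix also has to grow to match the tokenizer; a guard like the following (not in my original code) would cover that case:

# Resize embeddings only if the tokenizer outgrew the embedding table
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))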

tokenized_train_dataset = tokenized_data["train"]  # Access the "train" split

# Access the "validation" split only if it exists
tokenized_val_dataset = None
if "validation" in tokenized_data:
    tokenized_val_dataset = tokenized_data["validation"]
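
Since load_best_model_at_end=True and evaluation_strategy="epoch" need an eval set, a fallback when no validation split exists (just a sketch, not what my code currently does) would be to carve one out of the training data:

if tokenized_val_dataset is None:
    split = tokenized_data["train"].train_test_split(test_size=0.1, seed=42)
    tokenized_train_dataset = split["train"]
    tokenized_val_dataset = split["test"]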

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,  # Use tokenized_val_dataset for evaluation
    args=training_args,
    data_collator=data_collator
)


# Train the model
model.config.use_cache = False  # Silence the warnings. Please re-enable for inference!
trainer.train()  # Raises the AssertionError below

AssertionError                            Traceback (most recent call last)
Cell In[18], line 21
     19 # train model
     20 model.config.use_cache = False # silence the warnings. Please re-enable for inference!
---> 21 trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1780, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1778         hf_hub_utils.enable_progress_bars()
   1779 else:
-> 1780     return inner_training_loop(
   1781         args=args,
   1782         resume_from_checkpoint=resume_from_checkpoint,
   1783         trial=trial,
   1784         ignore_keys_for_eval=ignore_keys_for_eval,
   1785     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2181, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2178         grad_norm = _grad_norm
   2180 # Optimizer step
-> 2181 self.optimizer.step()
   2182 optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
   2183 if optimizer_was_run:
   2184     # Delay optimizer scheduling until metrics are generated

File /opt/conda/lib/python3.10/site-packages/accelerate/optimizer.py:136, in AcceleratedOptimizer.step(self, closure)
    133 if self.scaler is not None:
    134     self.optimizer.step = self._optimizer_patched_step_method
--> 136     self.scaler.step(self.optimizer, closure)
    137     self.scaler.update()
    139     if not self._accelerate_step_called:
    140         # If the optimizer step was skipped, gradient overflow was detected.

File /opt/conda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:448, in GradScaler.step(self, optimizer, *args, **kwargs)
    445 if optimizer_state["stage"] is OptState.READY:
    446     self.unscale_(optimizer)
--> 448 assert (
    449     len(optimizer_state["found_inf_per_device"]) > 0
    450 ), "No inf checks were recorded for this optimizer."
    452 retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
    454 optimizer_state["stage"] = OptState.STEPPED

AssertionError: No inf checks were recorded for this optimizer.


I changed the optimizer and set fp16=False, but I still get the same error!
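
Concretely, the retry looked roughly like this (the optim value here is just an example of a non-paged optimizer; the error is identical either way):

training_args = transformers.TrainingArguments(
    output_dir="workspace",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    fp16=False,           # fp16 disabled
    optim="adamw_torch",  # swapped away from paged_adamw_8bit
)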