Hello. Does anyone know whether there is an issue with the TRL library, or a way to fix it? I am trying to fine-tune Mistral 7B inside Kaggle. I already did it once with the same notebook, but since yesterday I have been getting an error when importing TRL. I already restarted the kernel and did a factory reset, but the error still happens.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer
This is the error:
File /opt/conda/lib/python3.10/site-packages/trl/trainer/__init__.py:44
42 from .ppo_trainer import PPOTrainer
43 from .reward_trainer import RewardTrainer, compute_accuracy
---> 44 from .sft_trainer import SFTTrainer
45 from .training_configs import RewardConfig
File /opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:23
21 import torch.nn as nn
22 from datasets import Dataset
---> 23 from datasets.arrow_writer import SchemaInferenceError
24 from datasets.builder import DatasetGenerationError
25 from transformers import (
26 AutoModelForCausalLM,
27 AutoTokenizer,
(...)
33 TrainingArguments,
34 )
ImportError: cannot import name 'SchemaInferenceError' from 'datasets.arrow_writer' (/opt/conda/lib/python3.10/site-packages/datasets/arrow_writer.py)
@DrPalmiere I fixed it in the Kaggle notebook by installing the latest version of the datasets package. It seems that by default an older version of datasets is used, one that does not define SchemaInferenceError. Use the following:
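A quick way to tell whether the datasets release your kernel sees is one of the older ones is to check whether the name TRL needs can actually be imported. This is a minimal sketch using only the standard library; the helper name can_import is hypothetical and not part of datasets or TRL.

```python
import importlib


def can_import(module_name: str, attr: str) -> bool:
    """Return True if `attr` (a class, function or constant) is available
    on `module_name`, i.e. `from module_name import attr` should work."""
    try:
        # Reuses the module from sys.modules if it is already loaded,
        # so the answer reflects what the running kernel actually sees.
        module = importlib.import_module(module_name)
    except ImportError:
        # The module or one of its parent packages is missing.
        return False
    return hasattr(module, attr)
```

In the notebook, can_import("datasets.arrow_writer", "SchemaInferenceError") returning False means the older datasets release is still the one being imported.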
Thanks, I am using this notebook too (Mistral-7B fine-tuning).
I tried to update the datasets library, but I got the same error. I fixed the problem by creating a fresh notebook and copying all the code from the old notebook into the new one, and it worked.
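One plausible reason an in-place upgrade seems not to work (the thread does not confirm this) is that datasets was already imported in the running kernel: pip replaces the files on disk, but the process keeps using the old module until the kernel restarts, which a fresh notebook does implicitly. Below is a minimal, standard-library-only sketch to compare the two versions; the helper name loaded_vs_installed is hypothetical.

```python
import importlib
import importlib.metadata
import sys


def loaded_vs_installed(package: str):
    """Return (loaded_version, installed_version) for `package`.

    loaded_version comes from the module already imported in this process
    (None if it has not been imported or has no __version__).
    installed_version comes from the package metadata on disk
    (None if no installed distribution is found).
    """
    module = sys.modules.get(package)
    loaded = getattr(module, "__version__", None) if module is not None else None
    # Refresh import-system caches so recently installed packages are found.
    importlib.invalidate_caches()
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        installed = None
    return loaded, installed
```

In the notebook, loaded_vs_installed("datasets") returning two different versions would suggest the upgrade is installed but not yet loaded, and restarting the kernel should pick it up.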
I tried that, but I'm still getting the same error:
ImportError: cannot import name 'SchemaInferenceError' from 'datasets.arrow_writer' (/opt/conda/lib/python3.10/site-packages/datasets/arrow_writer.py)