Hi,
I am trying to fine-tune the VITS model, but training fails with the error shown below.
Here is the code:
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForCausalLM
from custom_dataset import CustomDataset
# Paths and constants
train_filelist = "C:\\TTS\\vits\\Dataset\\train_filelist.csv"
val_filelist = "C:\\TTS\\vits\\Dataset\\val_filelist.csv"
model_name = "kakao-enterprise/vits-ljs" # Pretrained VITS model
use_auth_token = "hf_xxxxxxxxxxxxxxxxxxxx"  # Replace with your actual Hugging Face token
# Load dataset
train_dataset = CustomDataset(train_filelist)
eval_dataset = CustomDataset(val_filelist)
# Initialize model and tokenizer with token authentication
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=use_auth_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=use_auth_token)
# Define a proper data collator
class CustomDataCollator:
    def __call__(self, batch):
        # Debug prints
        print("Batch received by data_collator:", batch)
        # Extract input_values and labels
        audio_data = [item['input_values'] for item in batch]
        text_data = [item['labels'] for item in batch]
        # Convert to tensors
        audio_tensor = torch.stack(audio_data)
        text_tensor = torch.stack(text_data)
        return {'input_values': audio_tensor, 'labels': text_tensor}
data_collator = CustomDataCollator()
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    logging_dir="./logs",
    save_steps=500,
    logging_steps=100,
    eval_steps=500,
)
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# Start training
trainer.train()
And here is the error:
raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers.models.vits.configuration_vits.VitsConfig'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, ElectraConfig, ErnieConfig, FalconConfig, FuyuConfig, GemmaConfig, Gemma2Config, GitConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, JambaConfig, JetMoeConfig, LlamaConfig, MambaConfig, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MptConfig, MusicgenConfig, MusicgenMelodyConfig, MvpConfig, OlmoConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, Phi3Config, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, Qwen2MoeConfig, RecurrentGemmaConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, Speech2Text2Config, StableLmConfig, Starcoder2Config, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig
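As far as I can tell from the error, VITS is a text-to-speech model rather than a causal language model, so its config is not registered with AutoModelForCausalLM. My assumption is that the checkpoint has to be loaded with the VITS-specific model class instead, something like this minimal sketch (just loading the model, not my full fine-tuning setup):

from transformers import VitsModel, AutoTokenizer

# Load the VITS checkpoint with its model-specific class instead of
# AutoModelForCausalLM, since VITS has no causal language-modeling head
model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
tokenizer = AutoTokenizer.from_pretrained("kakao-enterprise/vits-ljs")

But I am not sure how to plug this into the Trainer for fine-tuning.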
How can this be solved?
Thanks