from transformers import MBartForConditionalGeneration, MBart50Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import Dataset, DatasetDict
from evaluate import load
import torch
import os
Configurar fallback para MPS
os.environ[“PYTORCH_ENABLE_MPS_FALLBACK”] = “1”
Verificar se MPS está disponível
device = torch.device(“mps”) if torch.backends.mps.is_available() else torch.device(“cpu”)
print(f"Usando o dispositivo: {device}")
1. Carregar o modelo e tokenizer
model_name = “facebook/mbart-large-50”
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
Configurar os idiomas
tokenizer.src_lang = “pt_XX”
tokenizer.tgt_lang = “en_XX”
2. Preparar os dados
data = {
“train”: [
{“pt”: “Olá, como você está?”, “en”: “Hello, how are you?”},
{“pt”: “Eu gosto de programar.”, “en”: “I like programming.”},
{“pt”: “O céu está azul.”, “en”: “The sky is blue.”},
],
“validation”: [
{“pt”: “Bom dia.”, “en”: “Good morning.”},
{“pt”: “Eu estou feliz.”, “en”: “I am happy.”},
],
}
Converter os dados para o formato da biblioteca datasets
train_dataset = Dataset.from_list(data[“train”])
val_dataset = Dataset.from_list(data[“validation”])
dataset = DatasetDict({“train”: train_dataset, “validation”: val_dataset})
3. Função de pré-processamento corrigida
def preprocess_function(examples):
inputs = examples[“pt”] # Obter a lista de frases em português
targets = examples[“en”] # Obter a lista de frases em inglês
model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding=“max_length”)
labels = tokenizer(targets, max_length=128, truncation=True, padding=“max_length”).input_ids
model_inputs[“labels”] = labels
return model_inputs
tokenized_datasets = dataset.map(preprocess_function, batched=True)
4. Configurar argumentos de treinamento
training_args = Seq2SeqTrainingArguments(
output_dir=“./results”, # Pasta para salvar os resultados
evaluation_strategy=“epoch”, # Avaliar ao final de cada época
learning_rate=2e-5, # Taxa de aprendizado
per_device_train_batch_size=4, # Ajustado para MPS
per_device_eval_batch_size=4, # Ajustado para MPS
save_total_limit=3, # Manter apenas 3 checkpoints
num_train_epochs=5, # Número de épocas
predict_with_generate=True, # Gerar previsões durante a validação
logging_dir=“./logs”, # Diretório para logs
logging_steps=10 # Registrar a cada 10 passos
)
5. Configurar o Trainer
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets[“train”],
eval_dataset=tokenized_datasets[“validation”],
tokenizer=tokenizer,
)
6. Treinar o modelo
trainer.train()
7. Testar o modelo em novas frases
test_sentences = [“Eu amo aprender novas linguagens.”, “O gato está dormindo.”]
for sentence in test_sentences:
# Tokenizar e enviar para o dispositivo
inputs = tokenizer(sentence, return_tensors=“pt”, max_length=128, truncation=True).to(device)
# Gerar tradução
translated_tokens = model.generate(
**inputs,
max_length=128,
num_beams=5, # Melhor qualidade com beam search
forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"] # Força o idioma de destino
)
# Decodificar e exibir
translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True).strip()
print(f"Original: {sentence}")
print(f"Tradução: {translation}\n")
8. Avaliação com BLEU usando evaluate
metric = load(“sacrebleu”)
Dados para avaliação
references = [[“Hello, how are you?”], [“I like programming.”], [“The sky is blue.”]]
predictions = [“Hello, how are you?”, “I like programming.”, “The sky is blue.”]
Calcular BLEU
results = metric.compute(predictions=predictions, references=references)
print(“BLEU Score:”, results[“score”])