Slow inference for translation

Hello there!
I am new to using pretrained models. I have 2,000,000 texts that I want to back-translate, and the problem is that it is really slow (an estimated 400+ hours). I am using an NVIDIA RTX 4080 Super OC. Here is the code I am using:

```python
import torch
from transformers import MarianTokenizer, MarianMTModel
from torch.cuda.amp import autocast
from tqdm import tqdm

# Set up Marian models for forward (en -> de) and backward (de -> en) translation
source_lang = 'en'
target_lang = 'de'
device = 'cuda'

model_name_f = f'Helsinki-NLP/opus-mt-{source_lang}-{target_lang}'
model_name_b = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'

tokenizer_f = MarianTokenizer.from_pretrained(model_name_f)
model_f = MarianMTModel.from_pretrained(model_name_f).to(device)

tokenizer_b = MarianTokenizer.from_pretrained(model_name_b)
model_b = MarianMTModel.from_pretrained(model_name_b).to(device)


def translate(tokens, model, batch_size, tokenizer, desc='Translating'):
    translated_texts = []
    total_batches = (tokens['input_ids'].size(0) + batch_size - 1) // batch_size

    for i in tqdm(range(0, tokens['input_ids'].size(0), batch_size), total=total_batches, desc=desc):
        # Slice the pre-tokenized inputs into one batch
        batch_tokens = {key: val[i:i + batch_size] for key, val in tokens.items()}
        with autocast():
            with torch.no_grad():
                generated_tokens = model.generate(**batch_tokens)
        # Decode each batch right away; generated lengths differ between batches,
        # so concatenating the raw token tensors with torch.cat would fail
        translated_texts.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

    return translated_texts


def batch_back_translate(texts, tokenizer_f, model_f, tokenizer_b, model_b, batch_size=64):
    # Tokenize all original texts up front
    print('Tokenizing original text...')
    n_texts = len(texts)
    tokens_f = tokenizer_f(texts, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)
    print(f'Tokenization complete for {n_texts} texts.')

    translated_f = translate(tokens_f, model_f, batch_size, tokenizer_f, desc='Forward Translating')

    # Tokenize the forward translations, then translate them back
    print('Tokenizing translated text...')
    tokens_b = tokenizer_b(translated_f, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)
    print(f'Tokenization complete for {len(translated_f)} translated texts.')

    translated_b = translate(tokens_b, model_b, batch_size, tokenizer_b, desc='Back Translating')

    return translated_b
```
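One thing I was wondering about while cleaning this up: I currently tokenize all 2,000,000 texts in a single call with padding=True, so every sequence gets padded to the longest text in the whole dataset and the full tensors sit on the GPU the entire time. I am not sure how much this contributes to the slowness, but here is a rough sketch of tokenizing per batch instead. It reuses the same imports, device, and models as above; the function name is just mine, and the batch_size of 64 and max_length of 256 are the values I already use.

```python
def translate_texts(texts, model, tokenizer, batch_size=64, max_length=256, desc='Translating'):
    # Sketch: tokenize each batch separately so padding only goes to the longest
    # text within that batch, not the longest text in the whole dataset.
    translated = []
    for i in tqdm(range(0, len(texts), batch_size), desc=desc):
        batch = texts[i:i + batch_size]
        tokens = tokenizer(batch, return_tensors='pt', padding=True,
                           truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            with autocast():
                generated = model.generate(**tokens)
        translated.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
    return translated

# Usage sketch: forward pass, then back-translate the results
# forward = translate_texts(texts, model_f, tokenizer_f, desc='Forward Translating')
# back = translate_texts(forward, model_b, tokenizer_b, desc='Back Translating')
```

Would something like this be expected to help, or is the generate call itself simply that slow for 2,000,000 texts?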

Apologies if the formatting is off; I was not sure how to get the code into a nice format. Hopefully you can see whether anything is wrong or could be improved.