This gist lists all newly added three-letter codes and their constituent languages (often other three-letter codes): https://gist.github.com/sshleifer/e79fbbabe0fab3da519fd39edffee4d2
If you know your language's ISO-639-3 code, you can cmd-F for it in that file to see which models will support it.
Backtranslation Snippet
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

mname_fwd = 'Helsinki-NLP/opus-mt-en-ceb'  # ceb = Cebuano, https://en.wikipedia.org/wiki/Cebuano_language
mname_bwd = 'Helsinki-NLP/opus-mt-ceb-en'
src_text = ['I am a small frog with tiny legs.']

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
fwd = AutoModelForSeq2SeqLM.from_pretrained(mname_fwd).to(torch_device)
fwd_tok = AutoTokenizer.from_pretrained(mname_fwd)
bwd = AutoModelForSeq2SeqLM.from_pretrained(mname_bwd).to(torch_device)
bwd_tok = AutoTokenizer.from_pretrained(mname_bwd)
if torch_device == 'cuda':
    fwd = fwd.half()
    bwd = bwd.half()

# forward translation: en -> ceb
fwd_batch = fwd_tok(src_text, return_tensors='pt').to(torch_device)
translated = fwd.generate(**fwd_batch, num_beams=2)
translated_txt = fwd_tok.batch_decode(translated, skip_special_tokens=True)

# backward translation: ceb -> en
bwd_batch = bwd_tok(translated_txt, return_tensors='pt').to(torch_device)
backtranslated = bwd.generate(**bwd_batch, num_beams=2)
result = bwd_tok.batch_decode(backtranslated, skip_special_tokens=True)
# ['I am a small toad with small feet.']
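For repeated use (e.g. in a data-augmentation loop), the steps above can be wrapped in a small helper. This is just a sketch, not part of the original post; `padding=True` and `truncation=True` are added assumptions so that multi-sentence batches of unequal length work:

```python
def backtranslate(texts, fwd, fwd_tok, bwd, bwd_tok, device='cpu', num_beams=2):
    """Round-trip a list of strings: source -> pivot language -> source."""
    # forward pass: source -> pivot
    fwd_batch = fwd_tok(texts, return_tensors='pt', padding=True, truncation=True).to(device)
    translated = fwd.generate(**fwd_batch, num_beams=num_beams)
    pivot_txt = fwd_tok.batch_decode(translated, skip_special_tokens=True)
    # backward pass: pivot -> source
    bwd_batch = bwd_tok(pivot_txt, return_tensors='pt', padding=True, truncation=True).to(device)
    back = bwd.generate(**bwd_batch, num_beams=num_beams)
    return bwd_tok.batch_decode(back, skip_special_tokens=True)
```

With the models loaded as above, `backtranslate(src_text, fwd, fwd_tok, bwd, bwd_tok, device=torch_device)` reproduces the snippet's result in one call.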
Wow! This looks great! I have one doubt: in my case, translation plus back-translation takes about 1 s. I have tried several models and different example translation scripts and it is still slow. I have an RTX 2080 Ti; am I missing something?
(1) Try bigger batches.
(2) Feel free to send your code, and don't count any of the lines before fwd_batch in your timing.
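The "bigger batches" suggestion amounts to calling generate once per large batch instead of once per sentence. A minimal sketch of the splitting step (the helper name `chunks` and the batch size are illustrative, not from the original post):

```python
def chunks(items, batch_size):
    """Yield successive batches of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# Hypothetical corpus: 10 sentences split into batches of 4.
sentences = ['sentence %d' % i for i in range(10)]
batches = list(chunks(sentences, 4))
# 3 batches of sizes 4, 4, 2; each batch goes through one
# fwd_tok(...) / fwd.generate(...) call instead of 10 separate calls.
```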
Hi, thanks for the quick answer. I'll post the code:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from timeit import default_timer as timer
src_text = ['Senior Level Software Engineer', 'Instrumental Software Technologies', 'Saratoga Springs, NY']
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch_device)
mname_fwd = 'Helsinki-NLP/opus-mt-en-es'
fwd = AutoModelForSeq2SeqLM.from_pretrained(mname_fwd).to(torch_device)
fwd_tok = AutoTokenizer.from_pretrained(mname_fwd)
if torch_device == 'cuda':
    fwd = fwd.half()
start = timer()
fwd_batch = fwd_tok(src_text, return_tensors='pt', padding=True, truncation=True).to(torch_device)
translated = fwd.generate(**fwd_batch)
translated_txt = fwd_tok.batch_decode(translated, skip_special_tokens=True)
end = timer()
print(translated_txt)
print(end - start)
cuda
['Ingeniero de software de nivel superior', 'Tecnologías de software instrumentales', 'Saratoga Springs, NY']
0.5867295530042611
587 ms / 3 samples ≈ 196 ms/sample
If you pass in src_text of length 128, you should get a similar runtime and therefore lower ms/sample.
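The arithmetic behind that claim can be sketched as follows; the batch-of-128 timing is a hypothetical illustrating why per-sample cost drops when the batch grows, not a measurement:

```python
# Measured above: 0.587 s to translate a batch of 3 short sentences.
batch_time_s = 0.587
per_sample_ms = batch_time_s / 3 * 1000   # ~196 ms/sample

# Hypothetically, if a batch of 128 sentences ran in similar wall-clock
# time (fixed generation overhead dominates at small batch sizes), the
# amortized per-sample cost would fall to a few milliseconds:
hypothetical_per_sample_ms = batch_time_s / 128 * 1000  # ~4.6 ms/sample
```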