TensorFlow in Part 2 of the course


I’m doing the second part of the course now, in particular the chapter “The 🤗 Datasets library”. In Part 1, I followed the TensorFlow option, but now only the PyTorch one seems to be available (when I select TensorFlow, it still shows the PyTorch-based tutorial). Are you planning to release the TensorFlow tutorial for Part 2 as well?

Thanks in advance!

Hi Lenn! All the sections have a TensorFlow version. Chapter 5 is completely framework agnostic, that’s why you don’t see any differences between the two, but if you look at chapter 7, you will see the content is very different.

Thanks for replying, sgugger!

The section “Semantic search with FAISS” in chapter 5 requires PyTorch, as you can see in the screenshot.

Hey @Lenn, sorry for the oversight on this section - I wrote it and forgot to include the TensorFlow equivalent code 🤦‍♂️

We’ll push a fix by the end of the week, but in the meantime you can use this code snippet to generate the embeddings in TensorFlow (skip the Colab cell with `model.to(device)`):

```python
from transformers import AutoTokenizer, TFAutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Load TensorFlow model from PyTorch checkpoint :)
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)

def cls_pooling(model_output):
    # Use the hidden state of the first ([CLS]) token as the sentence embedding
    return model_output.last_hidden_state[:, 0]

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="tf"
    )
    # No model.to(device) / tensor.to(device) step: TensorFlow places ops itself
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

# Compute embeddings (comments_dataset is the dataset built earlier in the chapter)
embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).numpy()[0]}
)
```
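For anyone wondering what the pooling step actually does: `cls_pooling` keeps only the first ([CLS]) token vector of each sequence, and the `dot-v1` in the checkpoint name indicates the model was trained so that embeddings are compared with a dot product. Below is a rough, framework-free sketch of that idea using plain Python lists in place of tensors. The helper names `dot` and `top_k` and the toy vectors are hypothetical, and this only illustrates the scoring; it is not how FAISS builds or searches its index.

```python
from typing import List, Sequence

Vector = List[float]


def cls_pooling(last_hidden_state: Sequence[Sequence[Sequence[float]]]) -> List[Vector]:
    # Keep only the first token ([CLS]) vector of each sequence in the batch,
    # mirroring last_hidden_state[:, 0] on a (batch, seq_len, hidden) tensor.
    return [list(sequence[0]) for sequence in last_hidden_state]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot-product similarity between two embeddings of equal length.
    return sum(x * y for x, y in zip(a, b))


def top_k(query: Sequence[float], corpus: Sequence[Sequence[float]], k: int = 2) -> List[int]:
    # Score every corpus embedding against the query and return the indices
    # of the k highest-scoring ones (ties keep their original order).
    scored = [(dot(query, emb), idx) for idx, emb in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [idx for _, idx in scored[:k]]


if __name__ == "__main__":
    # Toy "last hidden states": 3 sequences, 2 tokens each, hidden size 2.
    hidden = [
        [[1.0, 0.0], [0.5, 0.5]],
        [[0.0, 1.0], [0.2, 0.8]],
        [[0.7, 0.7], [0.1, 0.9]],
    ]
    corpus = cls_pooling(hidden)
    print(top_k([1.0, 0.2], corpus, k=2))
```

With the toy data the query scores 1.0, 0.2 and 0.84 against the three [CLS] vectors, so the script prints `[0, 2]`.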

Hope that helps!

cc @Rocketknight1 for visibility


Thank you very much for answering this and for the course, in general!

As for me, it was a good reason to start learning PyTorch 😁