Inference result not aligned with local version of same model and revision

Hello,
I am trying to run this embedding model “sentence-transformers/LaBSE” with revision=“836121a0533e5664b21c7aacc5d22951f2b8b25b” on the Inference Endpoints.

I get a result, but the embedding values differ from those produced by local execution, and the two sets of embeddings are not even correlated when compared with cosine similarity.

Any idea what’s going on?

from abc import ABC, abstractmethod
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from sbw_fiabilis.logger import get_logger, set_level
import os
from dotenv import load_dotenv

# Module-level logger used by the embedding services defined below.
logger = get_logger()


class EmbeddingInterface(ABC):
    """Abstract interface for embedding services."""

    @abstractmethod
    def encode(self, texts, batch_size=None, show_progress_bar=False):
        """Encode ``texts`` into embeddings.

        Concrete services must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None


class LocalEmbeddingService(EmbeddingInterface):
    """Local implementation backed by a SentenceTransformer model."""

    def __init__(self):
        """Load the pinned LaBSE revision, caching it under ``<working dir>/.hf``.

        The working directory comes from the ``WORKING_DIR`` environment
        variable, falling back to ``../../data/working_dir`` relative to
        this file.
        """
        fallback_dir = os.path.join(os.path.dirname(__file__), "../../data/working_dir")
        working_dir = os.getenv("WORKING_DIR", fallback_dir)
        hf_cache_dir = os.path.join(working_dir, ".hf")

        # The cache location is exported through the environment and also
        # passed explicitly to the model constructor.
        # NOTE(review): the Hugging Face libraries are already imported at
        # this point and may have read HF_HOME at import time — confirm
        # that setting it here still has the intended effect.
        os.environ["HF_HOME"] = hf_cache_dir

        self.model = SentenceTransformer(
            "sentence-transformers/LaBSE",
            revision="836121a0533e5664b21c7aacc5d22951f2b8b25b",
            cache_folder=hf_cache_dir,
        )
        logger.info("LocalEmbeddingService configuré")

    def encode(self, texts, batch_size=32, show_progress_bar=False):
        """Delegate to the model's ``encode`` and return its result unchanged."""
        encode_options = {
            "batch_size": batch_size,
            "show_progress_bar": show_progress_bar,
        }
        return self.model.encode(texts, **encode_options)


class APIEmbeddingService(EmbeddingInterface):
    """Implementation that calls a remote Hugging Face embedding API over HTTP."""

    def __init__(self):
        """Read the endpoint configuration from the environment.

        Raises:
            ValueError: If ``EMBEDDING_API_URL`` or ``EMBEDDING_API_KEY``
                is unset or empty.
        """
        self.api_url = os.getenv("EMBEDDING_API_URL")
        self.api_key = os.getenv("EMBEDDING_API_KEY")
        if not self.api_url or not self.api_key:
            raise ValueError("EMBEDDING_API_URL et EMBEDDING_API_KEY doivent être définis")
        self.headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        logger.info("ApiEmbeddingService configuré")

    def _query_api(self, payload):
        """POST ``payload`` as JSON to the endpoint and return the decoded JSON body.

        Raises:
            requests.exceptions.RequestException: On a transport error, a
                timeout (30 seconds) or a non-success HTTP status. The
                error is logged before being re-raised.
        """
        try:
            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Erreur lors de la requête API: {e}")
            raise

    def encode(self, texts, batch_size=32, show_progress_bar=False):
        """Encode ``texts`` through the remote API, one request per batch.

        Args:
            texts: Sequence of strings to embed.
            batch_size: Number of texts sent per request. ``None`` (the
                default declared on the abstract interface) sends all
                texts in a single request.
            show_progress_bar: Accepted for interface compatibility;
                not used by this implementation.

        Returns:
            A NumPy array built from the embeddings returned by the API,
            in input order. An empty array is returned for empty input.

        Raises:
            ValueError: If ``batch_size`` is not strictly positive, or if
                the API response is neither a list nor a dict containing
                an ``"embeddings"`` key.
            requests.exceptions.RequestException: If a request fails.
        """
        if not texts:
            return np.array([])

        total_texts = len(texts)

        # ``range`` rejects a None/zero step and silently yields nothing for
        # a negative one, so normalise and validate the batch size up front.
        if batch_size is None:
            batch_size = total_texts
        elif batch_size <= 0:
            raise ValueError("batch_size doit être un entier strictement positif")

        all_embeddings = []

        logger.info(f"Encodage via API: {total_texts} textes en lots de {batch_size}")

        for i in range(0, total_texts, batch_size):
            batch = texts[i:i + batch_size]

            payload = {
                "inputs": batch,
                "parameters": {}
            }

            response = self._query_api(payload)

            # Handle the different response formats the API may return.
            if isinstance(response, list):
                batch_embeddings = response
            elif isinstance(response, dict) and "embeddings" in response:
                batch_embeddings = response["embeddings"]
            else:
                raise ValueError(f"Format de réponse API inattendu: {type(response)}")

            all_embeddings.extend(batch_embeddings)

            logger.info(f"  Lot traité: {min(i + batch_size, total_texts)}/{total_texts}")

        # Return an ndarray on every path (the empty-input path already does),
        # so callers get the same container type regardless of input size.
        return np.asarray(all_embeddings)





def test():
    """Encode two sample texts with each service and log the first values.

    The local service runs first, then the API service; for each one the
    first five components of both embeddings are logged so the two
    back-ends can be compared by eye.
    """
    log = get_logger()
    set_level("DEBUG")

    load_dotenv()

    samples = ["toto", "tata"]

    for service_cls in (LocalEmbeddingService, APIEmbeddingService):
        vectors = service_cls().encode(samples)
        for position in (0, 1):
            log.info(vectors[position][:5])

# Run the local-vs-API comparison only when this file is executed as a script.
if __name__ == "__main__":
    test()
1 Like