huggingface[patch]: hide client field in HuggingFaceEmbeddings (#27522)

This commit is contained in:
Vadym Barda 2024-10-21 17:37:07 -04:00 committed by GitHub
parent 380449a7a9
commit 0640cbf2f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional # type: ignore[import-not-found]
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field
@@ -26,7 +26,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
)
"""
client: Any = None #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
cache_folder: Optional[str] = None
@@ -57,7 +56,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install sentence-transformers`."
) from exc
self.client = sentence_transformers.SentenceTransformer(
self._client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
@@ -79,12 +78,20 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
texts = list(map(lambda x: x.replace("\n", " "), texts))
if self.multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
pool = self._client.start_multi_process_pool()
embeddings = self._client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
embeddings = self._client.encode(
texts,
show_progress_bar=self.show_progress,
**self.encode_kwargs, # type: ignore
)
if isinstance(embeddings, list):
raise TypeError(
"Expected embeddings to be a Tensor or a numpy array, "
"got a list instead."
)
return embeddings.tolist()