mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
huggingface[patch]: hide client field in HuggingFaceEmbeddings (#27522)
This commit is contained in:
parent
380449a7a9
commit
0640cbf2f1
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional # type: ignore[import-not-found]
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
@ -26,7 +26,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any = None #: :meta private:
|
|
||||||
model_name: str = DEFAULT_MODEL_NAME
|
model_name: str = DEFAULT_MODEL_NAME
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
cache_folder: Optional[str] = None
|
cache_folder: Optional[str] = None
|
||||||
@ -57,7 +56,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
"Please install it with `pip install sentence-transformers`."
|
"Please install it with `pip install sentence-transformers`."
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
self.client = sentence_transformers.SentenceTransformer(
|
self._client = sentence_transformers.SentenceTransformer(
|
||||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -79,12 +78,20 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||||
if self.multi_process:
|
if self.multi_process:
|
||||||
pool = self.client.start_multi_process_pool()
|
pool = self._client.start_multi_process_pool()
|
||||||
embeddings = self.client.encode_multi_process(texts, pool)
|
embeddings = self._client.encode_multi_process(texts, pool)
|
||||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||||
else:
|
else:
|
||||||
embeddings = self.client.encode(
|
embeddings = self._client.encode(
|
||||||
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
texts,
|
||||||
|
show_progress_bar=self.show_progress,
|
||||||
|
**self.encode_kwargs, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(embeddings, list):
|
||||||
|
raise TypeError(
|
||||||
|
"Expected embeddings to be a Tensor or a numpy array, "
|
||||||
|
"got a list instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
Loading…
Reference in New Issue
Block a user