mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
Now `encode_kwargs` is used for both documents and queries, and this leads to wrong embeddings. E.g.:

```python
model_kwargs = {"device": "cuda", "trust_remote_code": True}
encode_kwargs = {"normalize_embeddings": False, "prompt_name": "s2p_query"}
model = HuggingFaceEmbeddings(
    model_name="dunzhang/stella_en_400M_v5",
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
query_embedding = np.array(
    model.embed_query("What are some ways to reduce stress?",)
)
document_embedding = np.array(
    model.embed_documents(
        [
            "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
            "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
        ]
    )
)
print(model._client.similarity(query_embedding, document_embedding))
# output: tensor([[0.8421, 0.3317]], dtype=torch.float64)
```

But according to the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5#sentence-transformers), the expected usage is:

```python
model_kwargs = {"device": "cuda", "trust_remote_code": True}
encode_kwargs = {"normalize_embeddings": False}
query_encode_kwargs = {"normalize_embeddings": False, "prompt_name": "s2p_query"}
model = HuggingFaceEmbeddings(
    model_name="dunzhang/stella_en_400M_v5",
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    query_encode_kwargs=query_encode_kwargs,
)
query_embedding = np.array(
    model.embed_query("What are some ways to reduce stress?", )
)
document_embedding = np.array(
    model.embed_documents(
        [
            "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
            "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
        ]
    )
)
print(model._client.similarity(query_embedding, document_embedding))
# tensor([[0.8398, 0.2990]], dtype=torch.float64)
```
136 lines
5.2 KiB
Python
136 lines
5.2 KiB
Python
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
# Default value of ``HuggingFaceEmbeddings.model_name`` when the caller does not
# pass one explicitly.
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
|
|
|
|
|
class HuggingFaceEmbeddings(BaseModel, Embeddings):
    """HuggingFace sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.

    Documents are encoded with ``encode_kwargs``. Queries are encoded with
    ``query_encode_kwargs`` when it is non-empty, and with ``encode_kwargs``
    otherwise. This lets models that need a query-specific prompt (for example
    via ``prompt_name``) be used correctly.

    Example:
        .. code-block:: python

            from langchain_huggingface import HuggingFaceEmbeddings

            model_name = "sentence-transformers/all-mpnet-base-v2"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': False}
            hf = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )

    Example with a separate query prompt:
        .. code-block:: python

            hf = HuggingFaceEmbeddings(
                model_name="dunzhang/stella_en_400M_v5",
                model_kwargs={"trust_remote_code": True},
                encode_kwargs={"normalize_embeddings": False},
                query_encode_kwargs={
                    "normalize_embeddings": False,
                    "prompt_name": "s2p_query",
                },
            )
    """

    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
    `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method for the documents of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
    These are also used for queries when `query_encode_kwargs` is empty.
    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
    query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method for the query of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
    If empty, `encode_kwargs` is used for queries instead.
    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
    multi_process: bool = False
    """Run encode() on multiple GPUs.
    Note: in this mode `encode_kwargs`, `query_encode_kwargs` and `show_progress`
    are not passed to `encode_multi_process`."""
    show_progress: bool = False
    """Whether to show a progress bar."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Raises:
            ImportError: If the ``sentence_transformers`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import sentence_transformers  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        self._client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    model_config = ConfigDict(
        extra="forbid",
        # Allow fields named `model_*` (model_name, model_kwargs) without
        # pydantic's protected-namespace warnings.
        protected_namespaces=(),
    )

    def _embed(
        self, texts: List[str], encode_kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        """Embed a list of texts using the HuggingFace transformer model.

        Args:
            texts: The list of texts to embed. Newlines are replaced with spaces
                before encoding.
            encode_kwargs: Keyword arguments to pass when calling the
                SentenceTransformer ``encode`` method. They are not forwarded
                when ``multi_process`` is enabled.

        Returns:
            List of embeddings, one for each text.

        Raises:
            TypeError: If the model returns a plain list instead of a tensor
                or numpy array.
        """
        import sentence_transformers  # type: ignore[import]

        texts = [text.replace("\n", " ") for text in texts]
        if self.multi_process:
            pool = self._client.start_multi_process_pool()
            try:
                embeddings = self._client.encode_multi_process(texts, pool)
            finally:
                # Always shut the worker pool down, even if encoding fails,
                # so worker processes are not leaked.
                sentence_transformers.SentenceTransformer.stop_multi_process_pool(
                    pool
                )
        else:
            embeddings = self._client.encode(
                texts,
                show_progress_bar=self.show_progress,
                **encode_kwargs,  # type: ignore
            )

        # `.tolist()` below needs a Tensor / ndarray, not a Python list.
        if isinstance(embeddings, list):
            raise TypeError(
                "Expected embeddings to be a Tensor or a numpy array, "
                "got a list instead."
            )

        return embeddings.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Documents are always encoded with ``encode_kwargs``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._embed(texts, self.encode_kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Uses ``query_encode_kwargs`` when it is non-empty, otherwise falls back
        to ``encode_kwargs``.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # An empty dict is falsy, so this falls back to the document kwargs.
        embed_kwargs = self.query_encode_kwargs or self.encode_kwargs
        return self._embed([text], embed_kwargs)[0]