mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
community[minor]: implement huggingface show_progress consistently (#22682)
- **Description:** This implements `show_progress` more consistently (i.e. it is also added to the `HuggingFaceBgeEmbeddings` object). - **Issue:** This implements `show_progress` more consistently in the embeddings huggingface classes. Previously this could have been set via `encode_kwargs`. - **Dependencies:** None - **Twitter handle:** @jonzeolla
This commit is contained in:
parent
74e705250f
commit
32ba8cfab0
@ -1,7 +1,8 @@
|
|||||||
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated, warn_deprecated
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr
|
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr
|
||||||
|
|
||||||
@ -154,6 +155,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Instruction to use for embedding documents."""
|
"""Instruction to use for embedding documents."""
|
||||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||||
"""Instruction to use for embedding query."""
|
"""Instruction to use for embedding query."""
|
||||||
|
show_progress: bool = False
|
||||||
|
"""Whether to show a progress bar."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
"""Initialize the sentence_transformer."""
|
"""Initialize the sentence_transformer."""
|
||||||
@ -167,6 +170,20 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
|
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
|
||||||
|
|
||||||
|
if "show_progress_bar" in self.encode_kwargs:
|
||||||
|
warn_deprecated(
|
||||||
|
since="0.2.5",
|
||||||
|
removal="0.4.0",
|
||||||
|
name="encode_kwargs['show_progress_bar']",
|
||||||
|
alternative=f"the show_progress method on {self.__class__.__name__}",
|
||||||
|
)
|
||||||
|
if self.show_progress:
|
||||||
|
warnings.warn(
|
||||||
|
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
|
||||||
|
"encode_kwargs['show_progress_bar'] takes precedence"
|
||||||
|
)
|
||||||
|
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -182,7 +199,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
||||||
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
|
embeddings = self.client.encode(
|
||||||
|
instruction_pairs,
|
||||||
|
show_progress_bar=self.show_progress,
|
||||||
|
**self.encode_kwargs,
|
||||||
|
)
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@ -195,7 +216,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
instruction_pair = [self.query_instruction, text]
|
instruction_pair = [self.query_instruction, text]
|
||||||
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
|
embedding = self.client.encode(
|
||||||
|
[instruction_pair],
|
||||||
|
show_progress_bar=self.show_progress,
|
||||||
|
**self.encode_kwargs,
|
||||||
|
)[0]
|
||||||
return embedding.tolist()
|
return embedding.tolist()
|
||||||
|
|
||||||
|
|
||||||
@ -252,6 +277,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Instruction to use for embedding query."""
|
"""Instruction to use for embedding query."""
|
||||||
embed_instruction: str = ""
|
embed_instruction: str = ""
|
||||||
"""Instruction to use for embedding document."""
|
"""Instruction to use for embedding document."""
|
||||||
|
show_progress: bool = False
|
||||||
|
"""Whether to show a progress bar."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
"""Initialize the sentence_transformer."""
|
"""Initialize the sentence_transformer."""
|
||||||
@ -268,9 +295,24 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
self.client = sentence_transformers.SentenceTransformer(
|
self.client = sentence_transformers.SentenceTransformer(
|
||||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if "-zh" in self.model_name:
|
if "-zh" in self.model_name:
|
||||||
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
|
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
|
||||||
|
|
||||||
|
if "show_progress_bar" in self.encode_kwargs:
|
||||||
|
warn_deprecated(
|
||||||
|
since="0.2.5",
|
||||||
|
removal="0.4.0",
|
||||||
|
name="encode_kwargs['show_progress_bar']",
|
||||||
|
alternative=f"the show_progress method on {self.__class__.__name__}",
|
||||||
|
)
|
||||||
|
if self.show_progress:
|
||||||
|
warnings.warn(
|
||||||
|
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
|
||||||
|
"encode_kwargs['show_progress_bar'] takes precedence"
|
||||||
|
)
|
||||||
|
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -286,7 +328,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
|
texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
|
||||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
embeddings = self.client.encode(
|
||||||
|
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
||||||
|
)
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@ -300,7 +344,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
embedding = self.client.encode(
|
embedding = self.client.encode(
|
||||||
self.query_instruction + text, **self.encode_kwargs
|
self.query_instruction + text,
|
||||||
|
show_progress_bar=self.show_progress,
|
||||||
|
**self.encode_kwargs,
|
||||||
)
|
)
|
||||||
return embedding.tolist()
|
return embedding.tolist()
|
||||||
|
|
||||||
@ -353,7 +399,9 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
from langchain_community.embeddings import (
|
||||||
|
HuggingFaceInferenceAPIEmbeddings,
|
||||||
|
)
|
||||||
|
|
||||||
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
|
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
|
||||||
api_key="your_api_key",
|
api_key="your_api_key",
|
||||||
|
Loading…
Reference in New Issue
Block a user