mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community[minor]: implement huggingface show_progress consistently (#22682)
- **Description:** This implements `show_progress` more consistently (i.e. it is also added to the `HuggingFaceBgeEmbeddings` object).
- **Issue:** This implements `show_progress` more consistently across the HuggingFace embeddings classes. Previously this could be set via `encode_kwargs`.
- **Dependencies:** None
- **Twitter handle:** @jonzeolla
This commit is contained in:
parent
74e705250f
commit
32ba8cfab0
@ -1,7 +1,8 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core._api import deprecated, warn_deprecated
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr
|
||||
|
||||
@ -154,6 +155,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
"""Instruction to use for embedding documents."""
|
||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||
"""Instruction to use for embedding query."""
|
||||
show_progress: bool = False
|
||||
"""Whether to show a progress bar."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
@ -167,6 +170,20 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
except ImportError as e:
|
||||
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
|
||||
|
||||
if "show_progress_bar" in self.encode_kwargs:
|
||||
warn_deprecated(
|
||||
since="0.2.5",
|
||||
removal="0.4.0",
|
||||
name="encode_kwargs['show_progress_bar']",
|
||||
alternative=f"the show_progress method on {self.__class__.__name__}",
|
||||
)
|
||||
if self.show_progress:
|
||||
warnings.warn(
|
||||
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
|
||||
"encode_kwargs['show_progress_bar'] takes precedence"
|
||||
)
|
||||
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -182,7 +199,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
||||
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
|
||||
embeddings = self.client.encode(
|
||||
instruction_pairs,
|
||||
show_progress_bar=self.show_progress,
|
||||
**self.encode_kwargs,
|
||||
)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -195,7 +216,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
instruction_pair = [self.query_instruction, text]
|
||||
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
|
||||
embedding = self.client.encode(
|
||||
[instruction_pair],
|
||||
show_progress_bar=self.show_progress,
|
||||
**self.encode_kwargs,
|
||||
)[0]
|
||||
return embedding.tolist()
|
||||
|
||||
|
||||
@ -252,6 +277,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
||||
"""Instruction to use for embedding query."""
|
||||
embed_instruction: str = ""
|
||||
"""Instruction to use for embedding document."""
|
||||
show_progress: bool = False
|
||||
"""Whether to show a progress bar."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
@ -268,9 +295,24 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
||||
self.client = sentence_transformers.SentenceTransformer(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
|
||||
if "-zh" in self.model_name:
|
||||
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
|
||||
|
||||
if "show_progress_bar" in self.encode_kwargs:
|
||||
warn_deprecated(
|
||||
since="0.2.5",
|
||||
removal="0.4.0",
|
||||
name="encode_kwargs['show_progress_bar']",
|
||||
alternative=f"the show_progress method on {self.__class__.__name__}",
|
||||
)
|
||||
if self.show_progress:
|
||||
warnings.warn(
|
||||
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
|
||||
"encode_kwargs['show_progress_bar'] takes precedence"
|
||||
)
|
||||
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -286,7 +328,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
embeddings = self.client.encode(
|
||||
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
||||
)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -300,7 +344,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
embedding = self.client.encode(
|
||||
self.query_instruction + text, **self.encode_kwargs
|
||||
self.query_instruction + text,
|
||||
show_progress_bar=self.show_progress,
|
||||
**self.encode_kwargs,
|
||||
)
|
||||
return embedding.tolist()
|
||||
|
||||
@ -353,7 +399,9 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
||||
from langchain_community.embeddings import (
|
||||
HuggingFaceInferenceAPIEmbeddings,
|
||||
)
|
||||
|
||||
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
|
||||
api_key="your_api_key",
|
||||
|
Loading…
Reference in New Issue
Block a user