mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
Improve query prompt of BGE embeddings (#8908)
Replace this comment with: - Description: Improved the query prompt of BGE embeddings after talking with the developers of the BGE embeddings, - Dependencies: any dependencies required for this change, - Tag maintainer: @hwchase17, - Twitter handle: @ManabChetia3 --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
c5c0735fc4
commit
4a07fba9f0
@@ -11,12 +11,10 @@ DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
|
|||||||
DEFAULT_QUERY_INSTRUCTION = (
|
DEFAULT_QUERY_INSTRUCTION = (
|
||||||
"Represent the question for retrieving supporting documents: "
|
"Represent the question for retrieving supporting documents: "
|
||||||
)
|
)
|
||||||
DEFAULT_EMBED_BGE_INSTRUCTION = (
|
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
|
||||||
"Represent this sentence for searching relevant passages: "
|
|
||||||
)
|
|
||||||
DEFAULT_QUERY_BGE_INSTRUCTION = (
|
|
||||||
"Represent this question for searching relevant passages: "
|
"Represent this question for searching relevant passages: "
|
||||||
)
|
)
|
||||||
|
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||||
@@ -190,7 +188,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
model_name = "BAAI/bge-large-en"
|
model_name = "BAAI/bge-large-en"
|
||||||
model_kwargs = {'device': 'cpu'}
|
model_kwargs = {'device': 'cpu'}
|
||||||
encode_kwargs = {'normalize_embeddings': False}
|
encode_kwargs = {'normalize_embeddings': True}
|
||||||
hf = HuggingFaceBgeEmbeddings(
|
hf = HuggingFaceBgeEmbeddings(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
@@ -208,9 +206,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Key word arguments to pass to the model."""
|
"""Key word arguments to pass to the model."""
|
||||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
"""Key word arguments to pass when calling the `encode` method of the model."""
|
"""Key word arguments to pass when calling the `encode` method of the model."""
|
||||||
embed_instruction: str = DEFAULT_EMBED_BGE_INSTRUCTION
|
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
|
||||||
"""Instruction to use for embedding documents."""
|
|
||||||
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION
|
|
||||||
"""Instruction to use for embedding query."""
|
"""Instruction to use for embedding query."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
@@ -228,6 +224,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
self.client = sentence_transformers.SentenceTransformer(
|
self.client = sentence_transformers.SentenceTransformer(
|
||||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||||
)
|
)
|
||||||
|
if "-zh" in self.model_name:
|
||||||
|
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@@ -243,8 +241,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
texts = [t.replace("\n", " ") for t in texts]
|
||||||
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
|
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@@ -256,6 +254,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
instruction_pair = [self.query_instruction, text]
|
text = text.replace("\n", " ")
|
||||||
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
|
embedding = self.client.encode(
|
||||||
|
self.query_instruction + text, **self.encode_kwargs
|
||||||
|
)
|
||||||
return embedding.tolist()
|
return embedding.tolist()
|
||||||
|
Loading…
Reference in New Issue
Block a user