diff --git a/libs/langchain/langchain/embeddings/huggingface.py b/libs/langchain/langchain/embeddings/huggingface.py index afbd0cd8064..6854ab4ce08 100644 --- a/libs/langchain/langchain/embeddings/huggingface.py +++ b/libs/langchain/langchain/embeddings/huggingface.py @@ -11,12 +11,10 @@ DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: " DEFAULT_QUERY_INSTRUCTION = ( "Represent the question for retrieving supporting documents: " ) -DEFAULT_EMBED_BGE_INSTRUCTION = ( - "Represent this sentence for searching relevant passages: " -) -DEFAULT_QUERY_BGE_INSTRUCTION = ( +DEFAULT_QUERY_BGE_INSTRUCTION_EN = ( "Represent this question for searching relevant passages: " ) +DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:" class HuggingFaceEmbeddings(BaseModel, Embeddings): @@ -190,7 +188,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): model_name = "BAAI/bge-large-en" model_kwargs = {'device': 'cpu'} - encode_kwargs = {'normalize_embeddings': False} + encode_kwargs = {'normalize_embeddings': True} hf = HuggingFaceBgeEmbeddings( model_name=model_name, model_kwargs=model_kwargs, @@ -208,9 +206,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): """Key word arguments to pass to the model.""" encode_kwargs: Dict[str, Any] = Field(default_factory=dict) """Key word arguments to pass when calling the `encode` method of the model.""" - embed_instruction: str = DEFAULT_EMBED_BGE_INSTRUCTION - """Instruction to use for embedding documents.""" - query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION + query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN """Instruction to use for embedding query.""" def __init__(self, **kwargs: Any): @@ -228,6 +224,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): self.client = sentence_transformers.SentenceTransformer( self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) + if "-zh" in self.model_name: + self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH class Config: """Configuration for this pydantic object.""" @@ -243,8 +241,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - instruction_pairs = [[self.embed_instruction, text] for text in texts] - embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs) +texts = [t.replace("\n", " ") for t in texts] + embeddings = self.client.encode(texts, **self.encode_kwargs) return embeddings.tolist() def embed_query(self, text: str) -> List[float]: @@ -256,6 +254,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - instruction_pair = [self.query_instruction, text] - embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0] + text = text.replace("\n", " ") + embedding = self.client.encode( + self.query_instruction + text, **self.encode_kwargs + ) return embedding.tolist()