From 239dc108527849e94f447ab4cc2c6e51fc1aab1f Mon Sep 17 00:00:00 2001 From: Zzz233 Date: Tue, 25 Apr 2023 08:20:08 +0800 Subject: [PATCH] ES similarity_search_with_score() and metadata filter (#3046) Add similarity_search_with_score() to ElasticVectorSearch, add metadata filter to both similarity_search() and similarity_search_with_score() --- .../vectorstores/elastic_vector_search.py | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 17af42c66ad..dc11a84269e 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid from abc import ABC -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -20,10 +20,15 @@ def _default_text_mapping(dim: int) -> Dict: } -def _default_script_query(query_vector: List[float]) -> Dict: +def _default_script_query(query_vector: List[float], filter: Optional[dict]) -> Dict: + if filter: + ((key, value),) = filter.items() + filter = {"match": {f"metadata.{key}.keyword": f"{value}"}} + else: + filter = {"match_all": {}} return { "script_score": { - "query": {"match_all": {}}, + "query": filter, "script": { "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", "params": {"query_vector": query_vector}, @@ -187,7 +192,7 @@ class ElasticVectorSearch(VectorStore, ABC): return ids def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any ) -> List[Document]: """Return docs most similar to query. @@ -198,15 +203,36 @@ class ElasticVectorSearch(VectorStore, ABC): Returns: List of Documents most similar to the query. """ - embedding = self.embedding.embed_query(query) - script_query = _default_script_query(embedding) - response = self.client.search(index=self.index_name, query=script_query, size=k) - hits = [hit["_source"] for hit in response["hits"]["hits"]] - documents = [ - Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits - ] + docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + documents = [d[0] for d in docs_and_scores] return documents + def similarity_search_with_score( + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding.embed_query(query) + script_query = _default_script_query(embedding, filter) + response = self.client.search(index=self.index_name, query=script_query, size=k) + hits = [hit for hit in response["hits"]["hits"]] + docs_and_scores = [ + ( + Document( + page_content=hit["_source"]["text"], + metadata=hit["_source"]["metadata"], + ), + hit["_score"], + ) + for hit in hits + ] + return docs_and_scores + @classmethod def from_texts( cls,