diff --git a/libs/langchain/langchain/vectorstores/elasticsearch.py b/libs/langchain/langchain/vectorstores/elasticsearch.py index 4bf2d4c47e5..3210cf4f9ea 100644 --- a/libs/langchain/langchain/vectorstores/elasticsearch.py +++ b/libs/langchain/langchain/vectorstores/elasticsearch.py @@ -117,10 +117,16 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): self, query_model_id: Optional[str] = None, hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, ): self.query_model_id = query_model_id self.hybrid = hybrid + # RRF has two optional parameters + # 'rank_constant', 'window_size' + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html + self.rrf = rrf + def query( self, query_vector: Union[List[float], None], @@ -161,8 +167,10 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): # If hybrid, add a query to the knn query # RRF is used to even the score from the knn query and text query + # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html if self.hybrid: - return { + query_body = { "knn": knn, "query": { "bool": { @@ -178,8 +186,14 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): "filter": filter, } }, - "rank": {"rrf": {}}, } + + if isinstance(self.rrf, dict): + query_body["rank"] = {"rrf": self.rrf} + elif isinstance(self.rrf, bool) and self.rrf is True: + query_body["rank"] = {"rrf": {}} + + return query_body else: return {"knn": knn} @@ -587,6 +601,7 @@ class ElasticsearchStore(VectorStore): self, query: str, k: int = 4, + fetch_k: int = 50, filter: Optional[List[dict]] = None, **kwargs: Any, ) -> List[Document]: @@ -595,6 +610,7 @@ class ElasticsearchStore(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to knn num_candidates. filter: Array of Elasticsearch filter clauses to apply to the query. Returns: @@ -602,7 +618,9 @@ class ElasticsearchStore(VectorStore): in descending order of similarity. """ - results = self._search(query=query, k=k, filter=filter, **kwargs) + results = self._search( + query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs + ) return [doc for doc, _ in results] def max_marginal_relevance_search( @@ -1187,6 +1205,7 @@ class ElasticsearchStore(VectorStore): def ApproxRetrievalStrategy( query_model_id: Optional[str] = None, hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, ) -> "ApproxRetrievalStrategy": """Used to perform approximate nearest neighbor search using the HNSW algorithm. @@ -1209,8 +1228,16 @@ class ElasticsearchStore(VectorStore): hybrid: Optional. If True, will perform a hybrid search using both the knn query and a text query. Defaults to False. + rrf: Optional. rrf is Reciprocal Rank Fusion. + When `hybrid` is True, + and `rrf` is True, then rrf: {}. + and `rrf` is False, then rrf is omitted. + and isinstance(rrf, dict) is True, then pass in the dict values. + rrf could be passed for adjusting 'rank_constant' and 'window_size'. """ - return ApproxRetrievalStrategy(query_model_id=query_model_id, hybrid=hybrid) + return ApproxRetrievalStrategy( + query_model_id=query_model_id, hybrid=hybrid, rrf=rrf + ) @staticmethod def SparseVectorRetrievalStrategy( diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py index d0a0d7f546e..8601e8a5cb2 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py @@ -481,6 +481,115 @@ class TestElasticsearch: output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + def test_similarity_search_approx_with_hybrid_search_rrf( + self, es_client: Any, elasticsearch_connection: dict, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + from functools import partial + from typing import Optional + + # 1. check query_body is okay + rrf_test_cases: List[Optional[Union[dict, bool]]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + texts = ["foo", "bar", "baz"] + docsearch = ElasticsearchStore.from_texts( + texts, + FakeEmbeddings(), + **elasticsearch_connection, + index_name=index_name, + strategy=ElasticsearchStore.ApproxRetrievalStrategy( + hybrid=True, rrf=rrf_test_case + ), + ) + + def assert_query( + query_body: dict, + query: str, + rrf: Optional[Union[dict, bool]] = True, + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + } + + if isinstance(rrf, dict): + cmp_query_body["rank"] = {"rrf": rrf} + elif isinstance(rrf, bool) and rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + ## without fetch_k parameter + output = docsearch.similarity_search( + "foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case) + ) + + # 2. check query result is okay + es_output = es_client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o.page_content for o in output] == [ + e["_source"]["text"] for e in es_output["hits"]["hits"] + ] + + # 3. check rrf default option is okay + docsearch = ElasticsearchStore.from_texts( + texts, + FakeEmbeddings(), + **elasticsearch_connection, + index_name=index_name, + strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True), + ) + + ## with fetch_k parameter + output = docsearch.similarity_search( + "foo", k=3, fetch_k=50, custom_query=assert_query + ) + def test_similarity_search_approx_with_custom_query_fn( self, elasticsearch_connection: dict, index_name: str ) -> None: