From d38c8369b3a131690c34eb13bf8c454aef15e224 Mon Sep 17 00:00:00 2001 From: HwangJohn Date: Sat, 28 Oct 2023 03:53:19 +0900 Subject: [PATCH] added rrf argument in ApproxRetrievalStrategy class __init__() (#11987) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** To support hybrid search with RRF (Reciprocal Rank Fusion) in Elasticsearch, an `rrf` argument was added so that 'rank_constant' and 'window_size' can be adjusted when combining multiple result sets with different relevance indicators into a single result set (ref: https://www.elastic.co/kr/blog/whats-new-elastic-enterprise-search-8-9-0), - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** No dependencies changed, - **Tag maintainer:** @baskaryan, Nice to meet you. I'm new to contributing and this is my first PR. I only changed the langchain/vectorstores/elasticsearch.py file. I ran make format and make lint, and got this message: ```shell make lint_diff ./scripts/check_pydantic.sh . ./scripts/check_imports.sh poetry run ruff . [ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run black langchain/vectorstores/elasticsearch.py --check All done! ✨ 🍰 ✨ 1 file would be left unchanged. 
[ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run mypy langchain/vectorstores/elasticsearch.py langchain/__init__.py: error: Source file found twice under different module names: "mvp.nlp.langchain.libs.langchain.langchain" and "langchain" Found 1 error in 1 file (errors prevented further checking) make: *** [lint_diff] Error 2 ``` Thank you --------- Co-authored-by: 황중원 --- .../langchain/vectorstores/elasticsearch.py | 35 +++++- .../vectorstores/test_elasticsearch.py | 109 ++++++++++++++++++ 2 files changed, 140 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/elasticsearch.py b/libs/langchain/langchain/vectorstores/elasticsearch.py index 4bf2d4c47e5..3210cf4f9ea 100644 --- a/libs/langchain/langchain/vectorstores/elasticsearch.py +++ b/libs/langchain/langchain/vectorstores/elasticsearch.py @@ -117,10 +117,16 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): self, query_model_id: Optional[str] = None, hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, ): self.query_model_id = query_model_id self.hybrid = hybrid + # RRF has two optional parameters + # 'rank_constant', 'window_size' + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html + self.rrf = rrf + def query( self, query_vector: Union[List[float], None], @@ -161,8 +167,10 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): # If hybrid, add a query to the knn query # RRF is used to even the score from the knn query and text query + # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html if self.hybrid: - return { + query_body = { "knn": knn, "query": { "bool": { @@ -178,8 +186,14 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy): "filter": filter, } }, - "rank": {"rrf": {}}, } + + if isinstance(self.rrf, dict): + query_body["rank"] = {"rrf": self.rrf} + elif isinstance(self.rrf, bool) and self.rrf is 
True: + query_body["rank"] = {"rrf": {}} + + return query_body else: return {"knn": knn} @@ -587,6 +601,7 @@ class ElasticsearchStore(VectorStore): self, query: str, k: int = 4, + fetch_k: int = 50, filter: Optional[List[dict]] = None, **kwargs: Any, ) -> List[Document]: @@ -595,6 +610,7 @@ class ElasticsearchStore(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to knn num_candidates. filter: Array of Elasticsearch filter clauses to apply to the query. Returns: @@ -602,7 +618,9 @@ class ElasticsearchStore(VectorStore): in descending order of similarity. """ - results = self._search(query=query, k=k, filter=filter, **kwargs) + results = self._search( + query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs + ) return [doc for doc, _ in results] def max_marginal_relevance_search( @@ -1187,6 +1205,7 @@ class ElasticsearchStore(VectorStore): def ApproxRetrievalStrategy( query_model_id: Optional[str] = None, hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, ) -> "ApproxRetrievalStrategy": """Used to perform approximate nearest neighbor search using the HNSW algorithm. @@ -1209,8 +1228,16 @@ class ElasticsearchStore(VectorStore): hybrid: Optional. If True, will perform a hybrid search using both the knn query and a text query. Defaults to False. + rrf: Optional. rrf is Reciprocal Rank Fusion. + When `hybrid` is True, + and `rrf` is True, then rrf: {}. + and `rrf` is False, then rrf is omitted. + and isinstance(rrf, dict) is True, then pass in the dict values. + rrf could be passed for adjusting 'rank_constant' and 'window_size'. 
""" - return ApproxRetrievalStrategy(query_model_id=query_model_id, hybrid=hybrid) + return ApproxRetrievalStrategy( + query_model_id=query_model_id, hybrid=hybrid, rrf=rrf + ) @staticmethod def SparseVectorRetrievalStrategy( diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py index d0a0d7f546e..8601e8a5cb2 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py @@ -481,6 +481,115 @@ class TestElasticsearch: output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + def test_similarity_search_approx_with_hybrid_search_rrf( + self, es_client: Any, elasticsearch_connection: dict, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + from functools import partial + from typing import Optional + + # 1. 
check query_body is okay + rrf_test_cases: List[Optional[Union[dict, bool]]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + texts = ["foo", "bar", "baz"] + docsearch = ElasticsearchStore.from_texts( + texts, + FakeEmbeddings(), + **elasticsearch_connection, + index_name=index_name, + strategy=ElasticsearchStore.ApproxRetrievalStrategy( + hybrid=True, rrf=rrf_test_case + ), + ) + + def assert_query( + query_body: dict, + query: str, + rrf: Optional[Union[dict, bool]] = True, + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + } + + if isinstance(rrf, dict): + cmp_query_body["rank"] = {"rrf": rrf} + elif isinstance(rrf, bool) and rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + ## without fetch_k parameter + output = docsearch.similarity_search( + "foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case) + ) + + # 2. check query result is okay + es_output = es_client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o.page_content for o in output] == [ + e["_source"]["text"] for e in es_output["hits"]["hits"] + ] + + # 3. 
check rrf default option is okay + docsearch = ElasticsearchStore.from_texts( + texts, + FakeEmbeddings(), + **elasticsearch_connection, + index_name=index_name, + strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True), + ) + + ## with fetch_k parameter + output = docsearch.similarity_search( + "foo", k=3, fetch_k=50, custom_query=assert_query + ) + def test_similarity_search_approx_with_custom_query_fn( self, elasticsearch_connection: dict, index_name: str ) -> None: