mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
added rrf argument in ApproxRetrievalStrategy class __init__() (#11987)
- **Description:** To support hybrid search with RRF (Reciprocal Rank Fusion) in Elasticsearch, an `rrf` argument was added so that 'rank_constant' and 'window_size' can be adjusted when combining multiple result sets with different relevance indicators into a single result set. (ref: https://www.elastic.co/kr/blog/whats-new-elastic-enterprise-search-8-9-0), - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** No dependencies changed, - **Tag maintainer:** @baskaryan, Nice to meet you. I'm new to contributing, and this is my first PR. I only changed the langchain/vectorstores/elasticsearch.py file. I ran make format and make lint, and got this message: ```shell make lint_diff ./scripts/check_pydantic.sh . ./scripts/check_imports.sh poetry run ruff . [ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run black langchain/vectorstores/elasticsearch.py --check All done! ✨ 🍰 ✨ 1 file would be left unchanged. [ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run mypy langchain/vectorstores/elasticsearch.py langchain/__init__.py: error: Source file found twice under different module names: "mvp.nlp.langchain.libs.langchain.langchain" and "langchain" Found 1 error in 1 file (errors prevented further checking) make: *** [lint_diff] Error 2 ``` Thank you --------- Co-authored-by: 황중원 <jwhwang@amorepacific.com>
This commit is contained in:
parent
2c58dca5f0
commit
d38c8369b3
@ -117,10 +117,16 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
|||||||
self,
|
self,
|
||||||
query_model_id: Optional[str] = None,
|
query_model_id: Optional[str] = None,
|
||||||
hybrid: Optional[bool] = False,
|
hybrid: Optional[bool] = False,
|
||||||
|
rrf: Optional[Union[dict, bool]] = True,
|
||||||
):
|
):
|
||||||
self.query_model_id = query_model_id
|
self.query_model_id = query_model_id
|
||||||
self.hybrid = hybrid
|
self.hybrid = hybrid
|
||||||
|
|
||||||
|
# RRF has two optional parameters
|
||||||
|
# 'rank_constant', 'window_size'
|
||||||
|
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
|
||||||
|
self.rrf = rrf
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
query_vector: Union[List[float], None],
|
query_vector: Union[List[float], None],
|
||||||
@ -161,8 +167,10 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
|||||||
|
|
||||||
# If hybrid, add a query to the knn query
|
# If hybrid, add a query to the knn query
|
||||||
# RRF is used to even the score from the knn query and text query
|
# RRF is used to even the score from the knn query and text query
|
||||||
|
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
|
||||||
|
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
|
||||||
if self.hybrid:
|
if self.hybrid:
|
||||||
return {
|
query_body = {
|
||||||
"knn": knn,
|
"knn": knn,
|
||||||
"query": {
|
"query": {
|
||||||
"bool": {
|
"bool": {
|
||||||
@ -178,8 +186,14 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
|||||||
"filter": filter,
|
"filter": filter,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"rank": {"rrf": {}},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isinstance(self.rrf, dict):
|
||||||
|
query_body["rank"] = {"rrf": self.rrf}
|
||||||
|
elif isinstance(self.rrf, bool) and self.rrf is True:
|
||||||
|
query_body["rank"] = {"rrf": {}}
|
||||||
|
|
||||||
|
return query_body
|
||||||
else:
|
else:
|
||||||
return {"knn": knn}
|
return {"knn": knn}
|
||||||
|
|
||||||
@ -587,6 +601,7 @@ class ElasticsearchStore(VectorStore):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
fetch_k: int = 50,
|
||||||
filter: Optional[List[dict]] = None,
|
filter: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
@ -595,6 +610,7 @@ class ElasticsearchStore(VectorStore):
|
|||||||
Args:
|
Args:
|
||||||
query: Text to look up documents similar to.
|
query: Text to look up documents similar to.
|
||||||
k: Number of Documents to return. Defaults to 4.
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k (int): Number of Documents to fetch to pass to knn num_candidates.
|
||||||
filter: Array of Elasticsearch filter clauses to apply to the query.
|
filter: Array of Elasticsearch filter clauses to apply to the query.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -602,7 +618,9 @@ class ElasticsearchStore(VectorStore):
|
|||||||
in descending order of similarity.
|
in descending order of similarity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = self._search(query=query, k=k, filter=filter, **kwargs)
|
results = self._search(
|
||||||
|
query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
|
||||||
|
)
|
||||||
return [doc for doc, _ in results]
|
return [doc for doc, _ in results]
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
@ -1187,6 +1205,7 @@ class ElasticsearchStore(VectorStore):
|
|||||||
def ApproxRetrievalStrategy(
|
def ApproxRetrievalStrategy(
|
||||||
query_model_id: Optional[str] = None,
|
query_model_id: Optional[str] = None,
|
||||||
hybrid: Optional[bool] = False,
|
hybrid: Optional[bool] = False,
|
||||||
|
rrf: Optional[Union[dict, bool]] = True,
|
||||||
) -> "ApproxRetrievalStrategy":
|
) -> "ApproxRetrievalStrategy":
|
||||||
"""Used to perform approximate nearest neighbor search
|
"""Used to perform approximate nearest neighbor search
|
||||||
using the HNSW algorithm.
|
using the HNSW algorithm.
|
||||||
@ -1209,8 +1228,16 @@ class ElasticsearchStore(VectorStore):
|
|||||||
hybrid: Optional. If True, will perform a hybrid search
|
hybrid: Optional. If True, will perform a hybrid search
|
||||||
using both the knn query and a text query.
|
using both the knn query and a text query.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
|
rrf: Optional. rrf is Reciprocal Rank Fusion.
|
||||||
|
When `hybrid` is True,
|
||||||
|
and `rrf` is True, then rrf: {}.
|
||||||
|
and `rrf` is False, then rrf is omitted.
|
||||||
|
and isinstance(rrf, dict) is True, then pass in the dict values.
|
||||||
|
rrf could be passed for adjusting 'rank_constant' and 'window_size'.
|
||||||
"""
|
"""
|
||||||
return ApproxRetrievalStrategy(query_model_id=query_model_id, hybrid=hybrid)
|
return ApproxRetrievalStrategy(
|
||||||
|
query_model_id=query_model_id, hybrid=hybrid, rrf=rrf
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def SparseVectorRetrievalStrategy(
|
def SparseVectorRetrievalStrategy(
|
||||||
|
@ -481,6 +481,115 @@ class TestElasticsearch:
|
|||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
def test_similarity_search_approx_with_hybrid_search_rrf(
|
||||||
|
self, es_client: Any, elasticsearch_connection: dict, index_name: str
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and rrf hybrid search with metadata."""
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# 1. check query_body is okay
|
||||||
|
rrf_test_cases: List[Optional[Union[dict, bool]]] = [
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
{"rank_constant": 1, "window_size": 5},
|
||||||
|
]
|
||||||
|
for rrf_test_case in rrf_test_cases:
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = ElasticsearchStore.from_texts(
|
||||||
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
**elasticsearch_connection,
|
||||||
|
index_name=index_name,
|
||||||
|
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
|
||||||
|
hybrid=True, rrf=rrf_test_case
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def assert_query(
|
||||||
|
query_body: dict,
|
||||||
|
query: str,
|
||||||
|
rrf: Optional[Union[dict, bool]] = True,
|
||||||
|
) -> dict:
|
||||||
|
cmp_query_body = {
|
||||||
|
"knn": {
|
||||||
|
"field": "vector",
|
||||||
|
"filter": [],
|
||||||
|
"k": 3,
|
||||||
|
"num_candidates": 50,
|
||||||
|
"query_vector": [
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"filter": [],
|
||||||
|
"must": [{"match": {"text": {"query": "foo"}}}],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(rrf, dict):
|
||||||
|
cmp_query_body["rank"] = {"rrf": rrf}
|
||||||
|
elif isinstance(rrf, bool) and rrf is True:
|
||||||
|
cmp_query_body["rank"] = {"rrf": {}}
|
||||||
|
|
||||||
|
assert query_body == cmp_query_body
|
||||||
|
|
||||||
|
return query_body
|
||||||
|
|
||||||
|
## without fetch_k parameter
|
||||||
|
output = docsearch.similarity_search(
|
||||||
|
"foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. check query result is okay
|
||||||
|
es_output = es_client.search(
|
||||||
|
index=index_name,
|
||||||
|
query={
|
||||||
|
"bool": {
|
||||||
|
"filter": [],
|
||||||
|
"must": [{"match": {"text": {"query": "foo"}}}],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
knn={
|
||||||
|
"field": "vector",
|
||||||
|
"filter": [],
|
||||||
|
"k": 3,
|
||||||
|
"num_candidates": 50,
|
||||||
|
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
||||||
|
},
|
||||||
|
size=3,
|
||||||
|
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [o.page_content for o in output] == [
|
||||||
|
e["_source"]["text"] for e in es_output["hits"]["hits"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. check rrf default option is okay
|
||||||
|
docsearch = ElasticsearchStore.from_texts(
|
||||||
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
**elasticsearch_connection,
|
||||||
|
index_name=index_name,
|
||||||
|
strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
## with fetch_k parameter
|
||||||
|
output = docsearch.similarity_search(
|
||||||
|
"foo", k=3, fetch_k=50, custom_query=assert_query
|
||||||
|
)
|
||||||
|
|
||||||
def test_similarity_search_approx_with_custom_query_fn(
|
def test_similarity_search_approx_with_custom_query_fn(
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
self, elasticsearch_connection: dict, index_name: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user