mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Add rrf argument to ApproxRetrievalStrategy.__init__() (#11987)
- **Description:** To support hybrid search with RRF (Reciprocal Rank Fusion) in Elasticsearch, an `rrf` argument was added so that 'rank_constant' and 'window_size' can be adjusted when combining multiple result sets with different relevance indicators into a single result set. (ref: https://www.elastic.co/kr/blog/whats-new-elastic-enterprise-search-8-9-0), - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** No dependencies changed, - **Tag maintainer:** @baskaryan, Nice to meet you, I'm new to contributing and this is my first PR. I only changed the langchain/vectorstores/elasticsearch.py file. I ran make format and make lint, and I got this message: ```shell make lint_diff ./scripts/check_pydantic.sh . ./scripts/check_imports.sh poetry run ruff . [ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run black langchain/vectorstores/elasticsearch.py --check All done! ✨ 🍰 ✨ 1 file would be left unchanged. [ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run mypy langchain/vectorstores/elasticsearch.py langchain/__init__.py: error: Source file found twice under different module names: "mvp.nlp.langchain.libs.langchain.langchain" and "langchain" Found 1 error in 1 file (errors prevented further checking) make: *** [lint_diff] Error 2 ``` Thank you --------- Co-authored-by: 황중원 <jwhwang@amorepacific.com>
This commit is contained in:
parent
2c58dca5f0
commit
d38c8369b3
@ -117,10 +117,16 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
||||
self,
|
||||
query_model_id: Optional[str] = None,
|
||||
hybrid: Optional[bool] = False,
|
||||
rrf: Optional[Union[dict, bool]] = True,
|
||||
):
|
||||
self.query_model_id = query_model_id
|
||||
self.hybrid = hybrid
|
||||
|
||||
# RRF has two optional parameters
|
||||
# 'rank_constant', 'window_size'
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
|
||||
self.rrf = rrf
|
||||
|
||||
def query(
|
||||
self,
|
||||
query_vector: Union[List[float], None],
|
||||
@ -161,8 +167,10 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
||||
|
||||
# If hybrid, add a query to the knn query
|
||||
# RRF is used to even the score from the knn query and text query
|
||||
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
|
||||
if self.hybrid:
|
||||
return {
|
||||
query_body = {
|
||||
"knn": knn,
|
||||
"query": {
|
||||
"bool": {
|
||||
@ -178,8 +186,14 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
||||
"filter": filter,
|
||||
}
|
||||
},
|
||||
"rank": {"rrf": {}},
|
||||
}
|
||||
|
||||
if isinstance(self.rrf, dict):
|
||||
query_body["rank"] = {"rrf": self.rrf}
|
||||
elif isinstance(self.rrf, bool) and self.rrf is True:
|
||||
query_body["rank"] = {"rrf": {}}
|
||||
|
||||
return query_body
|
||||
else:
|
||||
return {"knn": knn}
|
||||
|
||||
@ -587,6 +601,7 @@ class ElasticsearchStore(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 50,
|
||||
filter: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
@ -595,6 +610,7 @@ class ElasticsearchStore(VectorStore):
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to knn num_candidates.
|
||||
filter: Array of Elasticsearch filter clauses to apply to the query.
|
||||
|
||||
Returns:
|
||||
@ -602,7 +618,9 @@ class ElasticsearchStore(VectorStore):
|
||||
in descending order of similarity.
|
||||
"""
|
||||
|
||||
results = self._search(query=query, k=k, filter=filter, **kwargs)
|
||||
results = self._search(
|
||||
query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in results]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
@ -1187,6 +1205,7 @@ class ElasticsearchStore(VectorStore):
|
||||
def ApproxRetrievalStrategy(
|
||||
query_model_id: Optional[str] = None,
|
||||
hybrid: Optional[bool] = False,
|
||||
rrf: Optional[Union[dict, bool]] = True,
|
||||
) -> "ApproxRetrievalStrategy":
|
||||
"""Used to perform approximate nearest neighbor search
|
||||
using the HNSW algorithm.
|
||||
@ -1209,8 +1228,16 @@ class ElasticsearchStore(VectorStore):
|
||||
hybrid: Optional. If True, will perform a hybrid search
|
||||
using both the knn query and a text query.
|
||||
Defaults to False.
|
||||
rrf: Optional. rrf is Reciprocal Rank Fusion.
|
||||
When `hybrid` is True,
|
||||
and `rrf` is True, then rrf: {}.
|
||||
and `rrf` is False, then rrf is omitted.
|
||||
and isinstance(rrf, dict) is True, then pass in the dict values.
|
||||
rrf could be passed for adjusting 'rank_constant' and 'window_size'.
|
||||
"""
|
||||
return ApproxRetrievalStrategy(query_model_id=query_model_id, hybrid=hybrid)
|
||||
return ApproxRetrievalStrategy(
|
||||
query_model_id=query_model_id, hybrid=hybrid, rrf=rrf
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def SparseVectorRetrievalStrategy(
|
||||
|
@ -481,6 +481,115 @@ class TestElasticsearch:
|
||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_similarity_search_approx_with_hybrid_search_rrf(
|
||||
self, es_client: Any, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and rrf hybrid search with metadata."""
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
# 1. check query_body is okay
|
||||
rrf_test_cases: List[Optional[Union[dict, bool]]] = [
|
||||
True,
|
||||
False,
|
||||
{"rank_constant": 1, "window_size": 5},
|
||||
]
|
||||
for rrf_test_case in rrf_test_cases:
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
|
||||
hybrid=True, rrf=rrf_test_case
|
||||
),
|
||||
)
|
||||
|
||||
def assert_query(
|
||||
query_body: dict,
|
||||
query: str,
|
||||
rrf: Optional[Union[dict, bool]] = True,
|
||||
) -> dict:
|
||||
cmp_query_body = {
|
||||
"knn": {
|
||||
"field": "vector",
|
||||
"filter": [],
|
||||
"k": 3,
|
||||
"num_candidates": 50,
|
||||
"query_vector": [
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
0.0,
|
||||
],
|
||||
},
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [],
|
||||
"must": [{"match": {"text": {"query": "foo"}}}],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(rrf, dict):
|
||||
cmp_query_body["rank"] = {"rrf": rrf}
|
||||
elif isinstance(rrf, bool) and rrf is True:
|
||||
cmp_query_body["rank"] = {"rrf": {}}
|
||||
|
||||
assert query_body == cmp_query_body
|
||||
|
||||
return query_body
|
||||
|
||||
## without fetch_k parameter
|
||||
output = docsearch.similarity_search(
|
||||
"foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
|
||||
)
|
||||
|
||||
# 2. check query result is okay
|
||||
es_output = es_client.search(
|
||||
index=index_name,
|
||||
query={
|
||||
"bool": {
|
||||
"filter": [],
|
||||
"must": [{"match": {"text": {"query": "foo"}}}],
|
||||
}
|
||||
},
|
||||
knn={
|
||||
"field": "vector",
|
||||
"filter": [],
|
||||
"k": 3,
|
||||
"num_candidates": 50,
|
||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
||||
},
|
||||
size=3,
|
||||
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
|
||||
)
|
||||
|
||||
assert [o.page_content for o in output] == [
|
||||
e["_source"]["text"] for e in es_output["hits"]["hits"]
|
||||
]
|
||||
|
||||
# 3. check rrf default option is okay
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
|
||||
)
|
||||
|
||||
## with fetch_k parameter
|
||||
output = docsearch.similarity_search(
|
||||
"foo", k=3, fetch_k=50, custom_query=assert_query
|
||||
)
|
||||
|
||||
def test_similarity_search_approx_with_custom_query_fn(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user