From 3d23a5eb36045db3b7a05c34947b74bd4909ba3b Mon Sep 17 00:00:00 2001
From: Ryan French
Date: Sat, 20 Jan 2024 01:57:18 +0000
Subject: [PATCH] langchain[patch]: Allow OpenSearch Query Translator to correctly work with Date types (#16022)

**Description:** Fixes an issue where the Date type in an OpenSearch Self Querying Retriever would fail to generate a valid query
**Issue:** https://github.com/langchain-ai/langchain/issues/14225
---
 .../retrievers/self_query/opensearch.py       | 28 +++++-
 .../retrievers/self_query/test_opensearch.py  | 87 +++++++++++++++++++
 2 files changed, 111 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/retrievers/self_query/opensearch.py b/libs/langchain/langchain/retrievers/self_query/opensearch.py
index 502ee8a872b..bb27cddd0b4 100644
--- a/libs/langchain/langchain/retrievers/self_query/opensearch.py
+++ b/libs/langchain/langchain/retrievers/self_query/opensearch.py
@@ -58,11 +58,25 @@ class OpenSearchTranslator(Visitor):
             Comparator.GT,
             Comparator.GTE,
         ]:
-            return {
-                "range": {
-                    field: {self._format_func(comparison.comparator): comparison.value}
+            if isinstance(comparison.value, dict):
+                if "date" in comparison.value:
+                    return {
+                        "range": {
+                            field: {
+                                self._format_func(
+                                    comparison.comparator
+                                ): comparison.value["date"]
+                            }
+                        }
+                    }
+            else:
+                return {
+                    "range": {
+                        field: {
+                            self._format_func(comparison.comparator): comparison.value
+                        }
+                    }
                 }
-            }
 
         if comparison.comparator == Comparator.LIKE:
             return {
@@ -70,8 +84,13 @@ class OpenSearchTranslator(Visitor):
                     field: {"value": comparison.value}
                 }
             }
+
         field = f"{field}.keyword" if isinstance(comparison.value, str) else field
 
+        if isinstance(comparison.value, dict):
+            if "date" in comparison.value:
+                comparison.value = comparison.value["date"]
+
         return {self._format_func(comparison.comparator): {field: comparison.value}}
 
     def visit_structured_query(
@@ -81,4 +100,5 @@ class OpenSearchTranslator(Visitor):
             kwargs = {}
         else:
             kwargs = {"filter": structured_query.filter.accept(self)}
+
         return structured_query.query, kwargs
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py
index 629d195402e..93b6629ecb0 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py
+++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py
@@ -85,3 +85,90 @@ def test_visit_structured_query() -> None:
     )
     actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
     assert expected == actual
+
+
+def test_visit_structured_query_with_date_range() -> None:
+    query = "Who was the president of France in 1995?"
+    operation = Operation(
+        operator=Operator.AND,
+        arguments=[
+            Comparison(comparator=Comparator.EQ, attribute="foo", value="20"),
+            Operation(
+                operator=Operator.AND,
+                arguments=[
+                    Comparison(
+                        comparator=Comparator.GTE,
+                        attribute="timestamp",
+                        value={"date": "1995-01-01", "type": "date"},
+                    ),
+                    Comparison(
+                        comparator=Comparator.LT,
+                        attribute="timestamp",
+                        value={"date": "1996-01-01", "type": "date"},
+                    ),
+                ],
+            ),
+        ],
+    )
+    structured_query = StructuredQuery(query=query, filter=operation, limit=None)
+    expected = (
+        query,
+        {
+            "filter": {
+                "bool": {
+                    "must": [
+                        {"term": {"metadata.foo.keyword": "20"}},
+                        {
+                            "bool": {
+                                "must": [
+                                    {
+                                        "range": {
+                                            "metadata.timestamp": {"gte": "1995-01-01"}
+                                        }
+                                    },
+                                    {
+                                        "range": {
+                                            "metadata.timestamp": {"lt": "1996-01-01"}
+                                        }
+                                    },
+                                ]
+                            }
+                        },
+                    ]
+                }
+            }
+        },
+    )
+    actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
+    assert expected == actual
+
+
+def test_visit_structured_query_with_date() -> None:
+    query = "Who was the president of France on 1st of January 1995?"
+    operation = Operation(
+        operator=Operator.AND,
+        arguments=[
+            Comparison(comparator=Comparator.EQ, attribute="foo", value="20"),
+            Comparison(
+                comparator=Comparator.EQ,
+                attribute="timestamp",
+                value={"date": "1995-01-01", "type": "date"},
+            ),
+        ],
+    )
+    structured_query = StructuredQuery(query=query, filter=operation, limit=None)
+    expected = (
+        query,
+        {
+            "filter": {
+                "bool": {
+                    "must": [
+                        {"term": {"metadata.foo.keyword": "20"}},
+                        {"term": {"metadata.timestamp": "1995-01-01"}},
+                    ]
+                }
+            }
+        },
+    )
+    actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
+    assert expected == actual