diff --git a/libs/langchain/langchain/retrievers/self_query/weaviate.py b/libs/langchain/langchain/retrievers/self_query/weaviate.py
index 13ab09891c1..5b8652db595 100644
--- a/libs/langchain/langchain/retrievers/self_query/weaviate.py
+++ b/libs/langchain/langchain/retrievers/self_query/weaviate.py
@@ -1,3 +1,4 @@
+from datetime import date, datetime, timezone
 from typing import Dict, Tuple, Union
 
 from langchain.chains.query_constructor.ir import (
@@ -16,12 +17,28 @@ class WeaviateTranslator(Visitor):
     allowed_operators = [Operator.AND, Operator.OR]
     """Subset of allowed logical operators."""
 
-    allowed_comparators = [Comparator.EQ]
+    allowed_comparators = [
+        Comparator.EQ,
+        Comparator.NE,
+        Comparator.GTE,
+        Comparator.LTE,
+        Comparator.LT,
+        Comparator.GT,
+    ]
 
     def _format_func(self, func: Union[Operator, Comparator]) -> str:
         self._validate_func(func)
         # https://weaviate.io/developers/weaviate/api/graphql/filters
-        map_dict = {Operator.AND: "And", Operator.OR: "Or", Comparator.EQ: "Equal"}
+        map_dict = {
+            Operator.AND: "And",
+            Operator.OR: "Or",
+            Comparator.EQ: "Equal",
+            Comparator.NE: "NotEqual",
+            Comparator.GTE: "GreaterThanEqual",
+            Comparator.LTE: "LessThanEqual",
+            Comparator.LT: "LessThan",
+            Comparator.GT: "GreaterThan",
+        }
         return map_dict[func]
 
     def visit_operation(self, operation: Operation) -> Dict:
@@ -29,11 +46,35 @@ class WeaviateTranslator(Visitor):
         return {"operator": self._format_func(operation.operator), "operands": args}
 
     def visit_comparison(self, comparison: Comparison) -> Dict:
+        """Translate a comparison into a Weaviate ``where`` filter clause.
+
+        The ``value*`` key is chosen from the Python type of the comparison
+        value; anything that is not a bool, float, int, date or datetime is
+        sent as ``valueText``.
+        """
+        # Work on a local copy so the caller's Comparison is never mutated.
+        value = comparison.value
+        value_type = "valueText"
+        # bool must be tested before int because bool is a subclass of int.
+        if isinstance(value, bool):
+            value_type = "valueBoolean"
+        elif isinstance(value, float):
+            value_type = "valueNumber"
+        elif isinstance(value, int):
+            value_type = "valueInt"
+        elif isinstance(value, date):
+            # datetime is a subclass of date, so this branch covers both.
+            value_type = "valueDate"
+            if isinstance(value, datetime) and value.utcoffset() is not None:
+                # The trailing "Z" claims UTC, so convert aware values first.
+                value = value.astimezone(timezone.utc)
+            # ISO 8601 timestamp, formatted as RFC3339
+            value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
         return {
             "path": [comparison.attribute],
             "operator": self._format_func(comparison.comparator),
-            "valueText": comparison.value,
+            value_type: value,
         }
 
     def visit_structured_query(
         self, structured_query: StructuredQuery
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py
index 0a7385af455..1ba2985f44a 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py
+++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py
@@ -1,3 +1,4 @@
+from datetime import date, datetime, timedelta, timezone
 from typing import Dict, Tuple
 
 from langchain.chains.query_constructor.ir import (
@@ -19,18 +20,104 @@ def test_visit_comparison() -> None:
     assert expected == actual
 
 
+def test_visit_comparison_integer() -> None:
+    comp = Comparison(comparator=Comparator.GTE, attribute="foo", value=1)
+    expected = {"operator": "GreaterThanEqual", "path": ["foo"], "valueInt": 1}
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_number() -> None:
+    comp = Comparison(comparator=Comparator.GT, attribute="foo", value=1.4)
+    expected = {"operator": "GreaterThan", "path": ["foo"], "valueNumber": 1.4}
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_boolean() -> None:
+    comp = Comparison(comparator=Comparator.NE, attribute="foo", value=False)
+    expected = {"operator": "NotEqual", "path": ["foo"], "valueBoolean": False}
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_datetime() -> None:
+    comp = Comparison(
+        comparator=Comparator.LTE,
+        attribute="foo",
+        value=datetime(2023, 9, 13, 4, 20, 0),
+    )
+    expected = {
+        "operator": "LessThanEqual",
+        "path": ["foo"],
+        "valueDate": "2023-09-13T04:20:00Z",
+    }
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_date() -> None:
+    comp = Comparison(
+        comparator=Comparator.LT, attribute="foo", value=date(2023, 9, 13)
+    )
+    expected = {
+        "operator": "LessThan",
+        "path": ["foo"],
+        "valueDate": "2023-09-13T00:00:00Z",
+    }
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_datetime_with_timezone() -> None:
+    comp = Comparison(
+        comparator=Comparator.GT,
+        attribute="foo",
+        value=datetime(2023, 9, 13, 4, 20, 0, tzinfo=timezone(timedelta(hours=2))),
+    )
+    expected = {
+        "operator": "GreaterThan",
+        "path": ["foo"],
+        "valueDate": "2023-09-13T02:20:00Z",
+    }
+    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == actual
+
+
+def test_visit_comparison_does_not_mutate_value() -> None:
+    value = date(2023, 9, 13)
+    comp = Comparison(comparator=Comparator.LT, attribute="foo", value=value)
+    expected = {
+        "operator": "LessThan",
+        "path": ["foo"],
+        "valueDate": "2023-09-13T00:00:00Z",
+    }
+    # Visiting twice must give the same filter and leave the IR node untouched.
+    assert expected == DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert expected == DEFAULT_TRANSLATOR.visit_comparison(comp)
+    assert comp.value == value
+
+
 def test_visit_operation() -> None:
     op = Operation(
         operator=Operator.AND,
         arguments=[
-            Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
-            Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
+            Comparison(comparator=Comparator.EQ, attribute="foo", value="hello"),
+            Comparison(
+                comparator=Comparator.GTE, attribute="bar", value=date(2023, 9, 13)
+            ),
+            Comparison(comparator=Comparator.LTE, attribute="abc", value=1.4),
         ],
     )
     expected = {
         "operands": [
-            {"operator": "Equal", "path": ["foo"], "valueText": 2},
-            {"operator": "Equal", "path": ["bar"], "valueText": "baz"},
+            {"operator": "Equal", "path": ["foo"], "valueText": "hello"},
+            {
+                "operator": "GreaterThanEqual",
+                "path": ["bar"],
+                "valueDate": "2023-09-13T00:00:00Z",
+            },
+            {"operator": "LessThanEqual", "path": ["abc"], "valueNumber": 1.4},
         ],
         "operator": "And",
     }
@@ -78,7 +165,7 @@ def test_visit_structured_query() -> None:
             "where_filter": {
                 "operator": "And",
                 "operands": [
-                    {"path": ["foo"], "operator": "Equal", "valueText": 2},
+                    {"path": ["foo"], "operator": "Equal", "valueInt": 2},
                     {"path": ["bar"], "operator": "Equal", "valueText": "baz"},
                 ],
             }