mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 22:15:08 +00:00
Additional Weaviate Filter Comparators (#10522)
### Description When using Weaviate Self-Retrievers, certain common filter comparators generated by user queries were unimplemented, resulting in errors. This PR implements some of them. All linting and format commands have been run and tests passed. ### Issue #10474 ### Dependencies timestamp module --------- Co-authored-by: Patrick Randell <prandell@deloitte.com.au>
This commit is contained in:
parent
79011f835f
commit
1d678f805f
@ -1,3 +1,4 @@
|
|||||||
|
from datetime import date, datetime
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
from langchain.chains.query_constructor.ir import (
|
from langchain.chains.query_constructor.ir import (
|
||||||
@ -16,12 +17,28 @@ class WeaviateTranslator(Visitor):
|
|||||||
allowed_operators = [Operator.AND, Operator.OR]
|
allowed_operators = [Operator.AND, Operator.OR]
|
||||||
"""Subset of allowed logical operators."""
|
"""Subset of allowed logical operators."""
|
||||||
|
|
||||||
allowed_comparators = [Comparator.EQ]
|
allowed_comparators = [
|
||||||
|
Comparator.EQ,
|
||||||
|
Comparator.NE,
|
||||||
|
Comparator.GTE,
|
||||||
|
Comparator.LTE,
|
||||||
|
Comparator.LT,
|
||||||
|
Comparator.GT,
|
||||||
|
]
|
||||||
|
|
||||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||||
self._validate_func(func)
|
self._validate_func(func)
|
||||||
# https://weaviate.io/developers/weaviate/api/graphql/filters
|
# https://weaviate.io/developers/weaviate/api/graphql/filters
|
||||||
map_dict = {Operator.AND: "And", Operator.OR: "Or", Comparator.EQ: "Equal"}
|
map_dict = {
|
||||||
|
Operator.AND: "And",
|
||||||
|
Operator.OR: "Or",
|
||||||
|
Comparator.EQ: "Equal",
|
||||||
|
Comparator.NE: "NotEqual",
|
||||||
|
Comparator.GTE: "GreaterThanEqual",
|
||||||
|
Comparator.LTE: "LessThanEqual",
|
||||||
|
Comparator.LT: "LessThan",
|
||||||
|
Comparator.GT: "GreaterThan",
|
||||||
|
}
|
||||||
return map_dict[func]
|
return map_dict[func]
|
||||||
|
|
||||||
def visit_operation(self, operation: Operation) -> Dict:
|
def visit_operation(self, operation: Operation) -> Dict:
|
||||||
@ -29,11 +46,25 @@ class WeaviateTranslator(Visitor):
|
|||||||
return {"operator": self._format_func(operation.operator), "operands": args}
|
return {"operator": self._format_func(operation.operator), "operands": args}
|
||||||
|
|
||||||
def visit_comparison(self, comparison: Comparison) -> Dict:
|
def visit_comparison(self, comparison: Comparison) -> Dict:
|
||||||
return {
|
value_type = "valueText"
|
||||||
|
if isinstance(comparison.value, bool):
|
||||||
|
value_type = "valueBoolean"
|
||||||
|
elif isinstance(comparison.value, float):
|
||||||
|
value_type = "valueNumber"
|
||||||
|
elif isinstance(comparison.value, int):
|
||||||
|
value_type = "valueInt"
|
||||||
|
elif isinstance(comparison.value, datetime) or isinstance(
|
||||||
|
comparison.value, date
|
||||||
|
):
|
||||||
|
value_type = "valueDate"
|
||||||
|
# ISO 8601 timestamp, formatted as RFC3339
|
||||||
|
comparison.value = comparison.value.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
filter = {
|
||||||
"path": [comparison.attribute],
|
"path": [comparison.attribute],
|
||||||
"operator": self._format_func(comparison.comparator),
|
"operator": self._format_func(comparison.comparator),
|
||||||
"valueText": comparison.value,
|
|
||||||
}
|
}
|
||||||
|
filter[value_type] = comparison.value
|
||||||
|
return filter
|
||||||
|
|
||||||
def visit_structured_query(
|
def visit_structured_query(
|
||||||
self, structured_query: StructuredQuery
|
self, structured_query: StructuredQuery
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from datetime import date, datetime
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
from langchain.chains.query_constructor.ir import (
|
from langchain.chains.query_constructor.ir import (
|
||||||
@ -19,18 +20,75 @@ def test_visit_comparison() -> None:
|
|||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison_integer() -> None:
|
||||||
|
comp = Comparison(comparator=Comparator.GTE, attribute="foo", value=1)
|
||||||
|
expected = {"operator": "GreaterThanEqual", "path": ["foo"], "valueInt": 1}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison_number() -> None:
|
||||||
|
comp = Comparison(comparator=Comparator.GT, attribute="foo", value=1.4)
|
||||||
|
expected = {"operator": "GreaterThan", "path": ["foo"], "valueNumber": 1.4}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison_boolean() -> None:
|
||||||
|
comp = Comparison(comparator=Comparator.NE, attribute="foo", value=False)
|
||||||
|
expected = {"operator": "NotEqual", "path": ["foo"], "valueBoolean": False}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison_datetime() -> None:
|
||||||
|
comp = Comparison(
|
||||||
|
comparator=Comparator.LTE,
|
||||||
|
attribute="foo",
|
||||||
|
value=datetime(2023, 9, 13, 4, 20, 0),
|
||||||
|
)
|
||||||
|
expected = {
|
||||||
|
"operator": "LessThanEqual",
|
||||||
|
"path": ["foo"],
|
||||||
|
"valueDate": "2023-09-13T04:20:00Z",
|
||||||
|
}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison_date() -> None:
|
||||||
|
comp = Comparison(
|
||||||
|
comparator=Comparator.LT, attribute="foo", value=date(2023, 9, 13)
|
||||||
|
)
|
||||||
|
expected = {
|
||||||
|
"operator": "LessThan",
|
||||||
|
"path": ["foo"],
|
||||||
|
"valueDate": "2023-09-13T00:00:00Z",
|
||||||
|
}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
def test_visit_operation() -> None:
|
def test_visit_operation() -> None:
|
||||||
op = Operation(
|
op = Operation(
|
||||||
operator=Operator.AND,
|
operator=Operator.AND,
|
||||||
arguments=[
|
arguments=[
|
||||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
Comparison(comparator=Comparator.EQ, attribute="foo", value="hello"),
|
||||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
Comparison(
|
||||||
|
comparator=Comparator.GTE, attribute="bar", value=date(2023, 9, 13)
|
||||||
|
),
|
||||||
|
Comparison(comparator=Comparator.LTE, attribute="abc", value=1.4),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
expected = {
|
expected = {
|
||||||
"operands": [
|
"operands": [
|
||||||
{"operator": "Equal", "path": ["foo"], "valueText": 2},
|
{"operator": "Equal", "path": ["foo"], "valueText": "hello"},
|
||||||
{"operator": "Equal", "path": ["bar"], "valueText": "baz"},
|
{
|
||||||
|
"operator": "GreaterThanEqual",
|
||||||
|
"path": ["bar"],
|
||||||
|
"valueDate": "2023-09-13T00:00:00Z",
|
||||||
|
},
|
||||||
|
{"operator": "LessThanEqual", "path": ["abc"], "valueNumber": 1.4},
|
||||||
],
|
],
|
||||||
"operator": "And",
|
"operator": "And",
|
||||||
}
|
}
|
||||||
@ -78,7 +136,7 @@ def test_visit_structured_query() -> None:
|
|||||||
"where_filter": {
|
"where_filter": {
|
||||||
"operator": "And",
|
"operator": "And",
|
||||||
"operands": [
|
"operands": [
|
||||||
{"path": ["foo"], "operator": "Equal", "valueText": 2},
|
{"path": ["foo"], "operator": "Equal", "valueInt": 2},
|
||||||
{"path": ["bar"], "operator": "Equal", "valueText": "baz"},
|
{"path": ["bar"], "operator": "Equal", "valueText": "baz"},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user