mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Additional Weaviate Filter Comparators (#10522)
### Description When using Weaviate Self-Retrievers, certain common filter comparators generated by user queries were unimplemented, resulting in errors. This PR implements some of them. All linting and format commands have been run and tests passed. ### Issue #10474 ### Dependencies timestamp module --------- Co-authored-by: Patrick Randell <prandell@deloitte.com.au>
This commit is contained in:
parent
79011f835f
commit
1d678f805f
@ -1,3 +1,4 @@
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
@ -16,12 +17,28 @@ class WeaviateTranslator(Visitor):
|
||||
allowed_operators = [Operator.AND, Operator.OR]
|
||||
"""Subset of allowed logical operators."""
|
||||
|
||||
allowed_comparators = [Comparator.EQ]
|
||||
allowed_comparators = [
|
||||
Comparator.EQ,
|
||||
Comparator.NE,
|
||||
Comparator.GTE,
|
||||
Comparator.LTE,
|
||||
Comparator.LT,
|
||||
Comparator.GT,
|
||||
]
|
||||
|
||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||
self._validate_func(func)
|
||||
# https://weaviate.io/developers/weaviate/api/graphql/filters
|
||||
map_dict = {Operator.AND: "And", Operator.OR: "Or", Comparator.EQ: "Equal"}
|
||||
map_dict = {
|
||||
Operator.AND: "And",
|
||||
Operator.OR: "Or",
|
||||
Comparator.EQ: "Equal",
|
||||
Comparator.NE: "NotEqual",
|
||||
Comparator.GTE: "GreaterThanEqual",
|
||||
Comparator.LTE: "LessThanEqual",
|
||||
Comparator.LT: "LessThan",
|
||||
Comparator.GT: "GreaterThan",
|
||||
}
|
||||
return map_dict[func]
|
||||
|
||||
def visit_operation(self, operation: Operation) -> Dict:
|
||||
@ -29,11 +46,25 @@ class WeaviateTranslator(Visitor):
|
||||
return {"operator": self._format_func(operation.operator), "operands": args}
|
||||
|
||||
def visit_comparison(self, comparison: Comparison) -> Dict:
|
||||
return {
|
||||
value_type = "valueText"
|
||||
if isinstance(comparison.value, bool):
|
||||
value_type = "valueBoolean"
|
||||
elif isinstance(comparison.value, float):
|
||||
value_type = "valueNumber"
|
||||
elif isinstance(comparison.value, int):
|
||||
value_type = "valueInt"
|
||||
elif isinstance(comparison.value, datetime) or isinstance(
|
||||
comparison.value, date
|
||||
):
|
||||
value_type = "valueDate"
|
||||
# ISO 8601 timestamp, formatted as RFC3339
|
||||
comparison.value = comparison.value.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
filter = {
|
||||
"path": [comparison.attribute],
|
||||
"operator": self._format_func(comparison.comparator),
|
||||
"valueText": comparison.value,
|
||||
}
|
||||
filter[value_type] = comparison.value
|
||||
return filter
|
||||
|
||||
def visit_structured_query(
|
||||
self, structured_query: StructuredQuery
|
||||
|
@ -1,3 +1,4 @@
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
@ -19,18 +20,75 @@ def test_visit_comparison() -> None:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_comparison_integer() -> None:
|
||||
comp = Comparison(comparator=Comparator.GTE, attribute="foo", value=1)
|
||||
expected = {"operator": "GreaterThanEqual", "path": ["foo"], "valueInt": 1}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_comparison_number() -> None:
|
||||
comp = Comparison(comparator=Comparator.GT, attribute="foo", value=1.4)
|
||||
expected = {"operator": "GreaterThan", "path": ["foo"], "valueNumber": 1.4}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_comparison_boolean() -> None:
|
||||
comp = Comparison(comparator=Comparator.NE, attribute="foo", value=False)
|
||||
expected = {"operator": "NotEqual", "path": ["foo"], "valueBoolean": False}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_comparison_datetime() -> None:
|
||||
comp = Comparison(
|
||||
comparator=Comparator.LTE,
|
||||
attribute="foo",
|
||||
value=datetime(2023, 9, 13, 4, 20, 0),
|
||||
)
|
||||
expected = {
|
||||
"operator": "LessThanEqual",
|
||||
"path": ["foo"],
|
||||
"valueDate": "2023-09-13T04:20:00Z",
|
||||
}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_comparison_date() -> None:
|
||||
comp = Comparison(
|
||||
comparator=Comparator.LT, attribute="foo", value=date(2023, 9, 13)
|
||||
)
|
||||
expected = {
|
||||
"operator": "LessThan",
|
||||
"path": ["foo"],
|
||||
"valueDate": "2023-09-13T00:00:00Z",
|
||||
}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value="hello"),
|
||||
Comparison(
|
||||
comparator=Comparator.GTE, attribute="bar", value=date(2023, 9, 13)
|
||||
),
|
||||
Comparison(comparator=Comparator.LTE, attribute="abc", value=1.4),
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
"operands": [
|
||||
{"operator": "Equal", "path": ["foo"], "valueText": 2},
|
||||
{"operator": "Equal", "path": ["bar"], "valueText": "baz"},
|
||||
{"operator": "Equal", "path": ["foo"], "valueText": "hello"},
|
||||
{
|
||||
"operator": "GreaterThanEqual",
|
||||
"path": ["bar"],
|
||||
"valueDate": "2023-09-13T00:00:00Z",
|
||||
},
|
||||
{"operator": "LessThanEqual", "path": ["abc"], "valueNumber": 1.4},
|
||||
],
|
||||
"operator": "And",
|
||||
}
|
||||
@ -78,7 +136,7 @@ def test_visit_structured_query() -> None:
|
||||
"where_filter": {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
{"path": ["foo"], "operator": "Equal", "valueText": 2},
|
||||
{"path": ["foo"], "operator": "Equal", "valueInt": 2},
|
||||
{"path": ["bar"], "operator": "Equal", "valueText": "baz"},
|
||||
],
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user