Self-query template (#12694)

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur
2023-11-13 11:44:19 -08:00
committed by GitHub
parent 1e43025bf5
commit 2e42ed5de6
16 changed files with 3240 additions and 21 deletions

View File

@@ -1,5 +1,7 @@
import datetime
from typing import Any, Optional, Sequence, Union
from typing import Any, Literal, Optional, Sequence, Union
from typing_extensions import TypedDict
from langchain.utils import check_package_version
@@ -32,14 +34,14 @@ GRAMMAR = r"""
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| TIMESTAMP -> timestamp
| DATE -> date
| list
| string
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
TIMESTAMP.2: /["'](\d{4}-[01]\d-[0-3]\d)["']/
DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
@@ -52,6 +54,11 @@ GRAMMAR = r"""
"""
class ISO8601Date(TypedDict):
date: str
type: Literal["date"]
@v_args(inline=True)
class QueryTransformer(Transformer):
"""Transforms a query string into an intermediate representation."""
@@ -129,9 +136,16 @@ class QueryTransformer(Transformer):
def float(self, item: Any) -> float:
return float(item)
def timestamp(self, item: Any) -> datetime.date:
item = item.replace("'", '"')
return datetime.datetime.strptime(item, '"%Y-%m-%d"').date()
def date(self, item: Any) -> ISO8601Date:
item = str(item).strip("\"'")
try:
datetime.datetime.strptime(item, "%Y-%m-%d")
except ValueError as e:
raise ValueError(
"Dates are expected to be provided in ISO 8601 date format "
"(YYYY-MM-DD)."
) from e
return {"date": item, "type": "date"}
def string(self, item: Any) -> str:
# Remove escaped quotes

View File

@@ -141,7 +141,7 @@ A logical operation statement takes the form `op(statement1, statement2, ...)`:
Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values.
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\
"""
@@ -175,7 +175,7 @@ A logical operation statement takes the form `op(statement1, statement2, ...)`:
Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values.
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.

View File

@@ -1,4 +1,3 @@
import datetime
import re
from typing import Any, Callable, Dict, Tuple
@@ -106,9 +105,9 @@ class MyScaleTranslator(Visitor):
value = f"'{value}'" if isinstance(value, str) else value
# convert timestamp for datetime objects
if type(value) is datetime.date:
if isinstance(value, dict) and value.get("type") == "date":
attr = f"parseDateTime32BestEffort({attr})"
value = f"parseDateTime32BestEffort('{value.strftime('%Y-%m-%d')}')"
value = f"parseDateTime32BestEffort('{value['date']}')"
# string pattern match
if comp is Comparator.LIKE:

View File

@@ -1,4 +1,4 @@
from datetime import date, datetime
from datetime import datetime
from typing import Dict, Tuple, Union
from langchain.chains.query_constructor.ir import (
@@ -47,23 +47,26 @@ class WeaviateTranslator(Visitor):
def visit_comparison(self, comparison: Comparison) -> Dict:
value_type = "valueText"
value = comparison.value
if isinstance(comparison.value, bool):
value_type = "valueBoolean"
elif isinstance(comparison.value, float):
value_type = "valueNumber"
elif isinstance(comparison.value, int):
value_type = "valueInt"
elif isinstance(comparison.value, datetime) or isinstance(
comparison.value, date
elif (
isinstance(comparison.value, dict)
and comparison.value.get("type") == "date"
):
value_type = "valueDate"
# ISO 8601 timestamp, formatted as RFC3339
comparison.value = comparison.value.strftime("%Y-%m-%dT%H:%M:%SZ")
date = datetime.strptime(comparison.value["date"], "%Y-%m-%d")
value = date.strftime("%Y-%m-%dT%H:%M:%SZ")
filter = {
"path": [comparison.attribute],
"operator": self._format_func(comparison.comparator),
value_type: value,
}
filter[value_type] = comparison.value
return filter
def visit_structured_query(

View File

@@ -122,3 +122,10 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
expected = DEFAULT_PARSER.parse(arg)
actual = DEFAULT_PARSER.parse(f"{op}({arg})")
assert expected == actual
@pytest.mark.parametrize("x", ('"2022-10-20"', "'2022-10-20'", "2022-10-20"))
def test_parse_date_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value["date"]
assert actual == x.strip("'\"")

View File

@@ -1,4 +1,3 @@
from datetime import date, datetime
from typing import Dict, Tuple
from langchain.chains.query_constructor.ir import (
@@ -45,12 +44,12 @@ def test_visit_comparison_datetime() -> None:
comp = Comparison(
comparator=Comparator.LTE,
attribute="foo",
value=datetime(2023, 9, 13, 4, 20, 0),
value={"type": "date", "date": "2023-09-13"},
)
expected = {
"operator": "LessThanEqual",
"path": ["foo"],
"valueDate": "2023-09-13T04:20:00Z",
"valueDate": "2023-09-13T00:00:00Z",
}
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
@@ -58,7 +57,9 @@ def test_visit_comparison_datetime() -> None:
def test_visit_comparison_date() -> None:
comp = Comparison(
comparator=Comparator.LT, attribute="foo", value=date(2023, 9, 13)
comparator=Comparator.LT,
attribute="foo",
value={"type": "date", "date": "2023-09-13"},
)
expected = {
"operator": "LessThan",
@@ -75,7 +76,9 @@ def test_visit_operation() -> None:
arguments=[
Comparison(comparator=Comparator.EQ, attribute="foo", value="hello"),
Comparison(
comparator=Comparator.GTE, attribute="bar", value=date(2023, 9, 13)
comparator=Comparator.GTE,
attribute="bar",
value={"type": "date", "date": "2023-09-13"},
),
Comparison(comparator=Comparator.LTE, attribute="abc", value=1.4),
],