mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 05:30:39 +00:00
Add query parsing unit tests (#3672)
This commit is contained in:
parent
03c05b15f6
commit
b807a114e4
@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
parsed = parse_json_markdown(text, expected_keys)
|
parsed = parse_json_markdown(text, expected_keys)
|
||||||
if len(parsed["query"]) == 0:
|
if len(parsed["query"]) == 0:
|
||||||
parsed["query"] = " "
|
parsed["query"] = " "
|
||||||
if parsed["filter"] == "NO_FILTER":
|
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
||||||
parsed["filter"] = None
|
parsed["filter"] = None
|
||||||
else:
|
else:
|
||||||
parsed["filter"] = self.ast_parse(parsed["filter"])
|
parsed["filter"] = self.ast_parse(parsed["filter"])
|
||||||
|
@ -20,19 +20,21 @@ GRAMMAR = """
|
|||||||
|
|
||||||
func_call: CNAME "(" [args] ")"
|
func_call: CNAME "(" [args] ")"
|
||||||
|
|
||||||
?value: SIGNED_NUMBER -> number
|
?value: SIGNED_INT -> int
|
||||||
|
| SIGNED_FLOAT -> float
|
||||||
| list
|
| list
|
||||||
| string
|
| string
|
||||||
| "false" -> false
|
| ("false" | "False" | "FALSE") -> false
|
||||||
| "true" -> true
|
| ("true" | "True" | "TRUE") -> true
|
||||||
|
|
||||||
args: expr ("," expr)*
|
args: expr ("," expr)*
|
||||||
string: ESCAPED_STRING
|
string: /'[^']*'/ | ESCAPED_STRING
|
||||||
list: "[" [args] "]"
|
list: "[" [args] "]"
|
||||||
|
|
||||||
%import common.CNAME
|
%import common.CNAME
|
||||||
%import common.SIGNED_NUMBER
|
|
||||||
%import common.ESCAPED_STRING
|
%import common.ESCAPED_STRING
|
||||||
|
%import common.SIGNED_FLOAT
|
||||||
|
%import common.SIGNED_INT
|
||||||
%import common.WS
|
%import common.WS
|
||||||
%ignore WS
|
%ignore WS
|
||||||
"""
|
"""
|
||||||
@ -44,7 +46,7 @@ class QueryTransformer(Transformer):
|
|||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]],
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -93,9 +95,14 @@ class QueryTransformer(Transformer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def list(self, item: Any) -> list:
|
def list(self, item: Any) -> list:
|
||||||
|
if item is None:
|
||||||
|
return []
|
||||||
return list(item)
|
return list(item)
|
||||||
|
|
||||||
def number(self, item: Any) -> float:
|
def int(self, item: Any) -> int:
|
||||||
|
return int(item)
|
||||||
|
|
||||||
|
def float(self, item: Any) -> float:
|
||||||
return float(item)
|
return float(item)
|
||||||
|
|
||||||
def string(self, item: Any) -> str:
|
def string(self, item: Any) -> str:
|
||||||
|
@ -32,7 +32,7 @@ FULL_ANSWER = """\
|
|||||||
{{
|
{{
|
||||||
"query": "teenager love",
|
"query": "teenager love",
|
||||||
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
|
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
|
||||||
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))"
|
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
NO_FILTER_ANSWER = """\
|
NO_FILTER_ANSWER = """\
|
||||||
|
10
poetry.lock
generated
10
poetry.lock
generated
@ -571,7 +571,7 @@ name = "azure-core"
|
|||||||
version = "1.26.4"
|
version = "1.26.4"
|
||||||
description = "Microsoft Azure Core Library for Python"
|
description = "Microsoft Azure Core Library for Python"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = false
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"},
|
{file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"},
|
||||||
@ -3488,7 +3488,7 @@ name = "lark"
|
|||||||
version = "1.1.5"
|
version = "1.1.5"
|
||||||
description = "a modern parsing library"
|
description = "a modern parsing library"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
|
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
|
||||||
@ -9393,8 +9393,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
|||||||
cffi = ["cffi (>=1.11)"]
|
cffi = ["cffi (>=1.11)"]
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||||
azure = ["azure-cosmos", "azure-identity", "openai"]
|
azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
|
||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
embeddings = ["sentence-transformers"]
|
embeddings = ["sentence-transformers"]
|
||||||
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
|
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
|
||||||
@ -9404,4 +9404,4 @@ qdrant = ["qdrant-client"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40"
|
content-hash = "f7ff48dfce65630ea5c67287e91d923be83b9d0d9dd68639afcbc29f5f6f9c5f"
|
||||||
|
@ -99,6 +99,7 @@ pytest-watcher = "^0.2.6"
|
|||||||
freezegun = "^1.2.2"
|
freezegun = "^1.2.2"
|
||||||
responses = "^0.22.0"
|
responses = "^0.22.0"
|
||||||
pytest-asyncio = "^0.20.3"
|
pytest-asyncio = "^0.20.3"
|
||||||
|
lark = "^1.1.5"
|
||||||
|
|
||||||
[tool.poetry.group.test_integration]
|
[tool.poetry.group.test_integration]
|
||||||
optional = true
|
optional = true
|
||||||
|
116
tests/unit_tests/chains/query_constructor/test_parser.py
Normal file
116
tests/unit_tests/chains/query_constructor/test_parser.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
"""Test LLM-generated structured query parsing."""
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import lark
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.chains.query_constructor.ir import (
|
||||||
|
Comparator,
|
||||||
|
Comparison,
|
||||||
|
Operation,
|
||||||
|
Operator,
|
||||||
|
)
|
||||||
|
from langchain.chains.query_constructor.parser import get_parser
|
||||||
|
|
||||||
|
DEFAULT_PARSER = get_parser()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
|
||||||
|
def test_parse_invalid_grammar(x: str) -> None:
|
||||||
|
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
|
||||||
|
DEFAULT_PARSER.parse(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_comparison() -> None:
|
||||||
|
comp = 'gte("foo", 2)'
|
||||||
|
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
|
||||||
|
for input in (
|
||||||
|
comp,
|
||||||
|
comp.replace('"', "'"),
|
||||||
|
comp.replace(" ", ""),
|
||||||
|
comp.replace(" ", " "),
|
||||||
|
comp.replace("(", " ("),
|
||||||
|
comp.replace(",", ", "),
|
||||||
|
comp.replace("2", "2.0"),
|
||||||
|
):
|
||||||
|
actual = DEFAULT_PARSER.parse(input)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_operation() -> None:
|
||||||
|
op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
|
||||||
|
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
|
||||||
|
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
|
||||||
|
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
|
||||||
|
for input in (
|
||||||
|
op,
|
||||||
|
op.replace('"', "'"),
|
||||||
|
op.replace(" ", ""),
|
||||||
|
op.replace(" ", " "),
|
||||||
|
op.replace("(", " ("),
|
||||||
|
op.replace(",", ", "),
|
||||||
|
op.replace("25", "250"),
|
||||||
|
):
|
||||||
|
actual = DEFAULT_PARSER.parse(input)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_nested_operation() -> None:
|
||||||
|
op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
|
||||||
|
eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b")
|
||||||
|
eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c")
|
||||||
|
eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d")
|
||||||
|
eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
|
||||||
|
_not = Operation(operator=Operator.NOT, arguments=[eq4])
|
||||||
|
_or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3])
|
||||||
|
expected = Operation(operator=Operator.AND, arguments=[_or, _not])
|
||||||
|
actual = DEFAULT_PARSER.parse(op)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_disallowed_comparator() -> None:
|
||||||
|
parser = get_parser(allowed_comparators=[Comparator.EQ])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parser.parse('gt("a", 2)')
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_disallowed_operator() -> None:
|
||||||
|
parser = get_parser(allowed_operators=[Operator.AND])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parser.parse('not(gt("a", 2))')
|
||||||
|
|
||||||
|
|
||||||
|
def _test_parse_value(x: Any) -> None:
|
||||||
|
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
|
||||||
|
actual = parsed.value
|
||||||
|
assert actual == x
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
|
||||||
|
def test_parse_int_value(x: int) -> None:
|
||||||
|
_test_parse_value(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
|
||||||
|
def test_parse_float_value(x: float) -> None:
|
||||||
|
_test_parse_value(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
|
||||||
|
def test_parse_list_value(x: list) -> None:
|
||||||
|
_test_parse_value(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
|
||||||
|
def test_parse_string_value(x: str) -> None:
|
||||||
|
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||||
|
actual = parsed.value
|
||||||
|
assert actual == x[1:-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
|
||||||
|
def test_parse_bool_value(x: str) -> None:
|
||||||
|
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||||
|
actual = parsed.value
|
||||||
|
expected = x.lower() == "true"
|
||||||
|
assert actual == expected
|
Loading…
Reference in New Issue
Block a user