mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 05:30:39 +00:00
Add query parsing unit tests (#3672)
This commit is contained in:
parent
03c05b15f6
commit
b807a114e4
@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
parsed = parse_json_markdown(text, expected_keys)
|
||||
if len(parsed["query"]) == 0:
|
||||
parsed["query"] = " "
|
||||
if parsed["filter"] == "NO_FILTER":
|
||||
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
||||
parsed["filter"] = None
|
||||
else:
|
||||
parsed["filter"] = self.ast_parse(parsed["filter"])
|
||||
|
@ -20,19 +20,21 @@ GRAMMAR = """
|
||||
|
||||
func_call: CNAME "(" [args] ")"
|
||||
|
||||
?value: SIGNED_NUMBER -> number
|
||||
?value: SIGNED_INT -> int
|
||||
| SIGNED_FLOAT -> float
|
||||
| list
|
||||
| string
|
||||
| "false" -> false
|
||||
| "true" -> true
|
||||
| ("false" | "False" | "FALSE") -> false
|
||||
| ("true" | "True" | "TRUE") -> true
|
||||
|
||||
args: expr ("," expr)*
|
||||
string: ESCAPED_STRING
|
||||
string: /'[^']*'/ | ESCAPED_STRING
|
||||
list: "[" [args] "]"
|
||||
|
||||
%import common.CNAME
|
||||
%import common.SIGNED_NUMBER
|
||||
%import common.ESCAPED_STRING
|
||||
%import common.SIGNED_FLOAT
|
||||
%import common.SIGNED_INT
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
@ -44,7 +46,7 @@ class QueryTransformer(Transformer):
|
||||
self,
|
||||
*args: Any,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]],
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -93,9 +95,14 @@ class QueryTransformer(Transformer):
|
||||
return True
|
||||
|
||||
def list(self, item: Any) -> list:
|
||||
if item is None:
|
||||
return []
|
||||
return list(item)
|
||||
|
||||
def number(self, item: Any) -> float:
|
||||
def int(self, item: Any) -> int:
|
||||
return int(item)
|
||||
|
||||
def float(self, item: Any) -> float:
|
||||
return float(item)
|
||||
|
||||
def string(self, item: Any) -> str:
|
||||
|
@ -32,7 +32,7 @@ FULL_ANSWER = """\
|
||||
{{
|
||||
"query": "teenager love",
|
||||
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
|
||||
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))"
|
||||
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
||||
}}"""
|
||||
|
||||
NO_FILTER_ANSWER = """\
|
||||
|
10
poetry.lock
generated
10
poetry.lock
generated
@ -571,7 +571,7 @@ name = "azure-core"
|
||||
version = "1.26.4"
|
||||
description = "Microsoft Azure Core Library for Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"},
|
||||
@ -3488,7 +3488,7 @@ name = "lark"
|
||||
version = "1.1.5"
|
||||
description = "a modern parsing library"
|
||||
category = "main"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
|
||||
@ -9393,8 +9393,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
azure = ["azure-cosmos", "azure-identity", "openai"]
|
||||
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
|
||||
cohere = ["cohere"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
|
||||
@ -9404,4 +9404,4 @@ qdrant = ["qdrant-client"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40"
|
||||
content-hash = "f7ff48dfce65630ea5c67287e91d923be83b9d0d9dd68639afcbc29f5f6f9c5f"
|
||||
|
@ -99,6 +99,7 @@ pytest-watcher = "^0.2.6"
|
||||
freezegun = "^1.2.2"
|
||||
responses = "^0.22.0"
|
||||
pytest-asyncio = "^0.20.3"
|
||||
lark = "^1.1.5"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
116
tests/unit_tests/chains/query_constructor/test_parser.py
Normal file
116
tests/unit_tests/chains/query_constructor/test_parser.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Test LLM-generated structured query parsing."""
|
||||
from typing import Any, cast
|
||||
|
||||
import lark
|
||||
import pytest
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
)
|
||||
from langchain.chains.query_constructor.parser import get_parser
|
||||
|
||||
DEFAULT_PARSER = get_parser()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
|
||||
def test_parse_invalid_grammar(x: str) -> None:
|
||||
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
|
||||
DEFAULT_PARSER.parse(x)
|
||||
|
||||
|
||||
def test_parse_comparison() -> None:
|
||||
comp = 'gte("foo", 2)'
|
||||
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
|
||||
for input in (
|
||||
comp,
|
||||
comp.replace('"', "'"),
|
||||
comp.replace(" ", ""),
|
||||
comp.replace(" ", " "),
|
||||
comp.replace("(", " ("),
|
||||
comp.replace(",", ", "),
|
||||
comp.replace("2", "2.0"),
|
||||
):
|
||||
actual = DEFAULT_PARSER.parse(input)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_operation() -> None:
|
||||
op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
|
||||
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
|
||||
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
|
||||
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
|
||||
for input in (
|
||||
op,
|
||||
op.replace('"', "'"),
|
||||
op.replace(" ", ""),
|
||||
op.replace(" ", " "),
|
||||
op.replace("(", " ("),
|
||||
op.replace(",", ", "),
|
||||
op.replace("25", "250"),
|
||||
):
|
||||
actual = DEFAULT_PARSER.parse(input)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_nested_operation() -> None:
|
||||
op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
|
||||
eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b")
|
||||
eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c")
|
||||
eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d")
|
||||
eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
|
||||
_not = Operation(operator=Operator.NOT, arguments=[eq4])
|
||||
_or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3])
|
||||
expected = Operation(operator=Operator.AND, arguments=[_or, _not])
|
||||
actual = DEFAULT_PARSER.parse(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_disallowed_comparator() -> None:
|
||||
parser = get_parser(allowed_comparators=[Comparator.EQ])
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse('gt("a", 2)')
|
||||
|
||||
|
||||
def test_parse_disallowed_operator() -> None:
|
||||
parser = get_parser(allowed_operators=[Operator.AND])
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse('not(gt("a", 2))')
|
||||
|
||||
|
||||
def _test_parse_value(x: Any) -> None:
|
||||
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
|
||||
actual = parsed.value
|
||||
assert actual == x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
|
||||
def test_parse_int_value(x: int) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
|
||||
def test_parse_float_value(x: float) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
|
||||
def test_parse_list_value(x: list) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
|
||||
def test_parse_string_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
assert actual == x[1:-1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
|
||||
def test_parse_bool_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
expected = x.lower() == "true"
|
||||
assert actual == expected
|
Loading…
Reference in New Issue
Block a user