mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 09:23:57 +00:00
QoL improvements to query constructor (#11504)
Update the query constructor and self-query retriever to: - make it easier to pass in examples - validate the attributes used in a query - remove invalid parts of a query - make it easier to get and edit the prompt - make the query constructor a runnable - make the self-query retriever use the query constructor as a runnable
This commit is contained in:
parent
eec53fa294
commit
e7a0def1bc
@ -2,11 +2,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, List, Optional, Sequence
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.query_constructor.ir import (
|
from langchain.chains.query_constructor.ir import (
|
||||||
Comparator,
|
Comparator,
|
||||||
|
Comparison,
|
||||||
|
FilterDirective,
|
||||||
|
Operation,
|
||||||
Operator,
|
Operator,
|
||||||
StructuredQuery,
|
StructuredQuery,
|
||||||
)
|
)
|
||||||
@ -14,17 +17,21 @@ from langchain.chains.query_constructor.parser import get_parser
|
|||||||
from langchain.chains.query_constructor.prompt import (
|
from langchain.chains.query_constructor.prompt import (
|
||||||
DEFAULT_EXAMPLES,
|
DEFAULT_EXAMPLES,
|
||||||
DEFAULT_PREFIX,
|
DEFAULT_PREFIX,
|
||||||
DEFAULT_SCHEMA,
|
DEFAULT_SCHEMA_PROMPT,
|
||||||
DEFAULT_SUFFIX,
|
DEFAULT_SUFFIX,
|
||||||
EXAMPLE_PROMPT,
|
EXAMPLE_PROMPT,
|
||||||
EXAMPLES_WITH_LIMIT,
|
EXAMPLES_WITH_LIMIT,
|
||||||
SCHEMA_WITH_LIMIT,
|
PREFIX_WITH_DATA_SOURCE,
|
||||||
|
SCHEMA_WITH_LIMIT_PROMPT,
|
||||||
|
SUFFIX_WITHOUT_DATA_SOURCE,
|
||||||
|
USER_SPECIFIED_EXAMPLE_PROMPT,
|
||||||
)
|
)
|
||||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.schema.runnable import Runnable
|
||||||
|
|
||||||
|
|
||||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||||
@ -59,6 +66,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
cls,
|
cls,
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
|
allowed_attributes: Optional[Sequence[str]] = None,
|
||||||
|
fix_invalid: bool = False,
|
||||||
) -> StructuredQueryOutputParser:
|
) -> StructuredQueryOutputParser:
|
||||||
"""
|
"""
|
||||||
Create a structured query output parser from components.
|
Create a structured query output parser from components.
|
||||||
@ -70,13 +79,73 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
Returns:
|
Returns:
|
||||||
a structured query output parser
|
a structured query output parser
|
||||||
"""
|
"""
|
||||||
ast_parser = get_parser(
|
ast_parse: Callable
|
||||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
if fix_invalid:
|
||||||
|
|
||||||
|
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
|
||||||
|
filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
|
||||||
|
fixed = fix_filter_directive(
|
||||||
|
filter,
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
)
|
)
|
||||||
return cls(ast_parse=ast_parser.parse)
|
return fixed
|
||||||
|
|
||||||
|
else:
|
||||||
|
ast_parse = get_parser(
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
|
).parse
|
||||||
|
return cls(ast_parse=ast_parse)
|
||||||
|
|
||||||
|
|
||||||
def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
|
def fix_filter_directive(
|
||||||
|
filter: Optional[FilterDirective],
|
||||||
|
*,
|
||||||
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
|
allowed_attributes: Optional[Sequence[str]] = None,
|
||||||
|
) -> Optional[FilterDirective]:
|
||||||
|
if (
|
||||||
|
not (allowed_comparators or allowed_operators or allowed_attributes)
|
||||||
|
) or not filter:
|
||||||
|
return filter
|
||||||
|
|
||||||
|
elif isinstance(filter, Comparison):
|
||||||
|
if allowed_comparators and filter.comparator not in allowed_comparators:
|
||||||
|
return None
|
||||||
|
if allowed_attributes and filter.attribute not in allowed_attributes:
|
||||||
|
return None
|
||||||
|
return filter
|
||||||
|
elif isinstance(filter, Operation):
|
||||||
|
if allowed_operators and filter.operator not in allowed_operators:
|
||||||
|
return None
|
||||||
|
args = [
|
||||||
|
fix_filter_directive(
|
||||||
|
arg,
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
|
)
|
||||||
|
for arg in filter.arguments
|
||||||
|
]
|
||||||
|
args = [arg for arg in args if arg is not None]
|
||||||
|
if not args:
|
||||||
|
return None
|
||||||
|
elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
|
||||||
|
return args[0]
|
||||||
|
else:
|
||||||
|
return Operation(
|
||||||
|
operator=filter.operator,
|
||||||
|
arguments=args,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
|
||||||
info_dicts = {}
|
info_dicts = {}
|
||||||
for i in info:
|
for i in info:
|
||||||
i_dict = dict(i)
|
i_dict = dict(i)
|
||||||
@ -84,56 +153,90 @@ def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
|
|||||||
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
|
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
|
||||||
|
|
||||||
|
|
||||||
def _get_prompt(
|
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
|
||||||
|
examples = []
|
||||||
|
for i, (_input, output) in enumerate(input_output_pairs):
|
||||||
|
structured_request = (
|
||||||
|
json.dumps(output, indent=4).replace("{", "{{").replace("}", "}}")
|
||||||
|
)
|
||||||
|
example = {
|
||||||
|
"i": i + 1,
|
||||||
|
"user_query": _input,
|
||||||
|
"structured_request": structured_request,
|
||||||
|
}
|
||||||
|
examples.append(example)
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_constructor_prompt(
|
||||||
document_contents: str,
|
document_contents: str,
|
||||||
attribute_info: Sequence[AttributeInfo],
|
attribute_info: Sequence[Union[AttributeInfo, dict]],
|
||||||
examples: Optional[List] = None,
|
*,
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
examples: Optional[Sequence] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
|
||||||
|
allowed_operators: Sequence[Operator] = tuple(Operator),
|
||||||
enable_limit: bool = False,
|
enable_limit: bool = False,
|
||||||
|
schema_prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> BasePromptTemplate:
|
) -> BasePromptTemplate:
|
||||||
|
"""Create query construction prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_contents: The contents of the document to be queried.
|
||||||
|
attribute_info: A list of AttributeInfo objects describing
|
||||||
|
the attributes of the document.
|
||||||
|
examples: Optional list of examples to use for the chain.
|
||||||
|
allowed_comparators: Sequence of allowed comparators.
|
||||||
|
allowed_operators: Sequence of allowed operators.
|
||||||
|
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||||
|
schema_prompt: Prompt for describing query schema. Should have string input
|
||||||
|
variables allowed_comparators and allowed_operators.
|
||||||
|
**kwargs: Additional named params to pass to FewShotPromptTemplate init.
|
||||||
|
"""
|
||||||
|
default_schema_prompt = (
|
||||||
|
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
|
||||||
|
)
|
||||||
|
schema_prompt = schema_prompt or default_schema_prompt
|
||||||
attribute_str = _format_attribute_info(attribute_info)
|
attribute_str = _format_attribute_info(attribute_info)
|
||||||
allowed_comparators = allowed_comparators or list(Comparator)
|
schema = schema_prompt.format(
|
||||||
allowed_operators = allowed_operators or list(Operator)
|
|
||||||
if enable_limit:
|
|
||||||
schema = SCHEMA_WITH_LIMIT.format(
|
|
||||||
allowed_comparators=" | ".join(allowed_comparators),
|
allowed_comparators=" | ".join(allowed_comparators),
|
||||||
allowed_operators=" | ".join(allowed_operators),
|
allowed_operators=" | ".join(allowed_operators),
|
||||||
)
|
)
|
||||||
|
if examples and isinstance(examples[0], tuple):
|
||||||
examples = examples or EXAMPLES_WITH_LIMIT
|
examples = construct_examples(examples)
|
||||||
|
example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
|
||||||
|
prefix = PREFIX_WITH_DATA_SOURCE.format(
|
||||||
|
schema=schema, content=document_contents, attributes=attribute_str
|
||||||
|
)
|
||||||
|
suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
|
||||||
else:
|
else:
|
||||||
schema = DEFAULT_SCHEMA.format(
|
examples = examples or (
|
||||||
allowed_comparators=" | ".join(allowed_comparators),
|
EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
|
||||||
allowed_operators=" | ".join(allowed_operators),
|
|
||||||
)
|
)
|
||||||
|
example_prompt = EXAMPLE_PROMPT
|
||||||
examples = examples or DEFAULT_EXAMPLES
|
|
||||||
prefix = DEFAULT_PREFIX.format(schema=schema)
|
prefix = DEFAULT_PREFIX.format(schema=schema)
|
||||||
suffix = DEFAULT_SUFFIX.format(
|
suffix = DEFAULT_SUFFIX.format(
|
||||||
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
||||||
)
|
)
|
||||||
output_parser = StructuredQueryOutputParser.from_components(
|
|
||||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
|
||||||
)
|
|
||||||
return FewShotPromptTemplate(
|
return FewShotPromptTemplate(
|
||||||
examples=examples,
|
examples=list(examples),
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
example_prompt=example_prompt,
|
||||||
input_variables=["query"],
|
input_variables=["query"],
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
output_parser=output_parser,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_query_constructor_chain(
|
def load_query_constructor_chain(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
document_contents: str,
|
document_contents: str,
|
||||||
attribute_info: List[AttributeInfo],
|
attribute_info: Sequence[Union[AttributeInfo, dict]],
|
||||||
examples: Optional[List] = None,
|
examples: Optional[List] = None,
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_operators: Sequence[Operator] = tuple(Operator),
|
||||||
enable_limit: bool = False,
|
enable_limit: bool = False,
|
||||||
|
schema_prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMChain:
|
) -> LLMChain:
|
||||||
"""Load a query constructor chain.
|
"""Load a query constructor chain.
|
||||||
@ -141,25 +244,95 @@ def load_query_constructor_chain(
|
|||||||
Args:
|
Args:
|
||||||
llm: BaseLanguageModel to use for the chain.
|
llm: BaseLanguageModel to use for the chain.
|
||||||
document_contents: The contents of the document to be queried.
|
document_contents: The contents of the document to be queried.
|
||||||
attribute_info: A list of AttributeInfo objects describing
|
attribute_info: Sequence of attributes in the document.
|
||||||
the attributes of the document.
|
|
||||||
examples: Optional list of examples to use for the chain.
|
examples: Optional list of examples to use for the chain.
|
||||||
allowed_comparators: An optional list of allowed comparators.
|
allowed_comparators: Sequence of allowed comparators. Defaults to all
|
||||||
allowed_operators: An optional list of allowed operators.
|
Comparators.
|
||||||
|
allowed_operators: Sequence of allowed operators. Defaults to all Operators.
|
||||||
enable_limit: Whether to enable the limit operator. Defaults to False.
|
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||||
**kwargs:
|
schema_prompt: Prompt for describing query schema. Should have string input
|
||||||
|
variables allowed_comparators and allowed_operators.
|
||||||
|
**kwargs: Arbitrary named params to pass to LLMChain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LLMChain that can be used to construct queries.
|
A LLMChain that can be used to construct queries.
|
||||||
"""
|
"""
|
||||||
prompt = _get_prompt(
|
prompt = get_query_constructor_prompt(
|
||||||
document_contents,
|
document_contents,
|
||||||
attribute_info,
|
attribute_info,
|
||||||
examples=examples,
|
examples=examples,
|
||||||
allowed_comparators=allowed_comparators,
|
allowed_comparators=allowed_comparators,
|
||||||
allowed_operators=allowed_operators,
|
allowed_operators=allowed_operators,
|
||||||
enable_limit=enable_limit,
|
enable_limit=enable_limit,
|
||||||
|
schema_prompt=schema_prompt,
|
||||||
)
|
)
|
||||||
return LLMChain(
|
allowed_attributes = []
|
||||||
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
|
for ainfo in attribute_info:
|
||||||
|
allowed_attributes.append(
|
||||||
|
ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
|
||||||
)
|
)
|
||||||
|
output_parser = StructuredQueryOutputParser.from_components(
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
|
)
|
||||||
|
# For backwards compatibility.
|
||||||
|
prompt.output_parser = output_parser
|
||||||
|
return LLMChain(llm=llm, prompt=prompt, output_parser=output_parser, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def load_query_constructor_runnable(
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
document_contents: str,
|
||||||
|
attribute_info: Sequence[Union[AttributeInfo, dict]],
|
||||||
|
*,
|
||||||
|
examples: Optional[Sequence] = None,
|
||||||
|
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
|
||||||
|
allowed_operators: Sequence[Operator] = tuple(Operator),
|
||||||
|
enable_limit: bool = False,
|
||||||
|
schema_prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
fix_invalid: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable:
|
||||||
|
"""Load a query constructor runnable chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: BaseLanguageModel to use for the chain.
|
||||||
|
document_contents: The contents of the document to be queried.
|
||||||
|
attribute_info: Sequence of attributes in the document.
|
||||||
|
examples: Optional list of examples to use for the chain.
|
||||||
|
allowed_comparators: Sequence of allowed comparators. Defaults to all
|
||||||
|
Comparators.
|
||||||
|
allowed_operators: Sequence of allowed operators. Defaults to all Operators.
|
||||||
|
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||||
|
schema_prompt: Prompt for describing query schema. Should have string input
|
||||||
|
variables allowed_comparators and allowed_operators.
|
||||||
|
fix_invalid: Whether to fix invalid filter directives by ignoring invalid
|
||||||
|
operators, comparators and attributes.
|
||||||
|
**kwargs: Additional named params to pass to FewShotPromptTemplate init.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that can be used to construct queries.
|
||||||
|
"""
|
||||||
|
prompt = get_query_constructor_prompt(
|
||||||
|
document_contents,
|
||||||
|
attribute_info,
|
||||||
|
examples=examples,
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
enable_limit=enable_limit,
|
||||||
|
schema_prompt=schema_prompt,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
allowed_attributes = []
|
||||||
|
for ainfo in attribute_info:
|
||||||
|
allowed_attributes.append(
|
||||||
|
ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
|
||||||
|
)
|
||||||
|
output_parser = StructuredQueryOutputParser.from_components(
|
||||||
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
|
fix_invalid=fix_invalid,
|
||||||
|
)
|
||||||
|
return prompt | llm | output_parser
|
||||||
|
@ -61,11 +61,13 @@ class QueryTransformer(Transformer):
|
|||||||
*args: Any,
|
*args: Any,
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
|
allowed_attributes: Optional[Sequence[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.allowed_comparators = allowed_comparators
|
self.allowed_comparators = allowed_comparators
|
||||||
self.allowed_operators = allowed_operators
|
self.allowed_operators = allowed_operators
|
||||||
|
self.allowed_attributes = allowed_attributes
|
||||||
|
|
||||||
def program(self, *items: Any) -> tuple:
|
def program(self, *items: Any) -> tuple:
|
||||||
return items
|
return items
|
||||||
@ -73,6 +75,11 @@ class QueryTransformer(Transformer):
|
|||||||
def func_call(self, func_name: Any, args: list) -> FilterDirective:
|
def func_call(self, func_name: Any, args: list) -> FilterDirective:
|
||||||
func = self._match_func_name(str(func_name))
|
func = self._match_func_name(str(func_name))
|
||||||
if isinstance(func, Comparator):
|
if isinstance(func, Comparator):
|
||||||
|
if self.allowed_attributes and args[0] not in self.allowed_attributes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid attributes {args[0]}. Allowed attributes are "
|
||||||
|
f"{self.allowed_attributes}"
|
||||||
|
)
|
||||||
return Comparison(comparator=func, attribute=args[0], value=args[1])
|
return Comparison(comparator=func, attribute=args[0], value=args[1])
|
||||||
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
|
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
|
||||||
return args[0]
|
return args[0]
|
||||||
@ -134,6 +141,7 @@ class QueryTransformer(Transformer):
|
|||||||
def get_parser(
|
def get_parser(
|
||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
|
allowed_attributes: Optional[Sequence[str]] = None,
|
||||||
) -> Lark:
|
) -> Lark:
|
||||||
"""
|
"""
|
||||||
Returns a parser for the query language.
|
Returns a parser for the query language.
|
||||||
@ -151,6 +159,8 @@ def get_parser(
|
|||||||
"Cannot import lark, please install it with 'pip install lark'."
|
"Cannot import lark, please install it with 'pip install lark'."
|
||||||
)
|
)
|
||||||
transformer = QueryTransformer(
|
transformer = QueryTransformer(
|
||||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
allowed_comparators=allowed_comparators,
|
||||||
|
allowed_operators=allowed_operators,
|
||||||
|
allowed_attributes=allowed_attributes,
|
||||||
)
|
)
|
||||||
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")
|
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")
|
||||||
|
@ -3,36 +3,31 @@ from langchain.prompts import PromptTemplate
|
|||||||
|
|
||||||
SONG_DATA_SOURCE = """\
|
SONG_DATA_SOURCE = """\
|
||||||
```json
|
```json
|
||||||
{
|
{{
|
||||||
"content": "Lyrics of a song",
|
"content": "Lyrics of a song",
|
||||||
"attributes": {
|
"attributes": {{
|
||||||
"artist": {
|
"artist": {{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Name of the song artist"
|
"description": "Name of the song artist"
|
||||||
},
|
}},
|
||||||
"length": {
|
"length": {{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Length of the song in seconds"
|
"description": "Length of the song in seconds"
|
||||||
},
|
}},
|
||||||
"genre": {
|
"genre": {{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The song genre, one of \"pop\", \"rock\" or \"rap\""
|
"description": "The song genre, one of \"pop\", \"rock\" or \"rap\""
|
||||||
}
|
}}
|
||||||
}
|
}}
|
||||||
}
|
}}
|
||||||
```\
|
```\
|
||||||
""".replace(
|
"""
|
||||||
"{", "{{"
|
|
||||||
).replace(
|
|
||||||
"}", "}}"
|
|
||||||
)
|
|
||||||
|
|
||||||
FULL_ANSWER = """\
|
FULL_ANSWER = """\
|
||||||
```json
|
```json
|
||||||
{{
|
{{
|
||||||
"query": "teenager love",
|
"query": "teenager love",
|
||||||
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
|
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
||||||
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
|
||||||
}}
|
}}
|
||||||
```\
|
```\
|
||||||
"""
|
"""
|
||||||
@ -104,16 +99,24 @@ Structured Request:
|
|||||||
{structured_request}
|
{structured_request}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
EXAMPLE_PROMPT = PromptTemplate(
|
EXAMPLE_PROMPT = PromptTemplate.from_template(EXAMPLE_PROMPT_TEMPLATE)
|
||||||
input_variables=["i", "data_source", "user_query", "structured_request"],
|
|
||||||
template=EXAMPLE_PROMPT_TEMPLATE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
USER_SPECIFIED_EXAMPLE_PROMPT = PromptTemplate.from_template(
|
||||||
|
"""\
|
||||||
|
<< Example {i}. >>
|
||||||
|
User Query:
|
||||||
|
{user_query}
|
||||||
|
|
||||||
|
Structured Request:
|
||||||
|
```json
|
||||||
|
{structured_request}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_SCHEMA = """\
|
DEFAULT_SCHEMA = """\
|
||||||
<< Structured Request Schema >>
|
<< Structured Request Schema >>
|
||||||
When responding use a markdown code snippet with a JSON object formatted in the \
|
When responding use a markdown code snippet with a JSON object formatted in the following schema:
|
||||||
following schema:
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{{{{
|
{{{{
|
||||||
@ -122,11 +125,9 @@ following schema:
|
|||||||
}}}}
|
}}}}
|
||||||
```
|
```
|
||||||
|
|
||||||
The query string should contain only text that is expected to match the contents of \
|
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||||
documents. Any conditions in the filter should not be mentioned in the query as well.
|
|
||||||
|
|
||||||
A logical condition statement is composed of one or more comparison and logical \
|
A logical condition statement is composed of one or more comparison and logical operation statements.
|
||||||
operation statements.
|
|
||||||
|
|
||||||
A comparison statement takes the form: `comp(attr, val)`:
|
A comparison statement takes the form: `comp(attr, val)`:
|
||||||
- `comp` ({allowed_comparators}): comparator
|
- `comp` ({allowed_comparators}): comparator
|
||||||
@ -135,24 +136,20 @@ A comparison statement takes the form: `comp(attr, val)`:
|
|||||||
|
|
||||||
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
||||||
- `op` ({allowed_operators}): logical operator
|
- `op` ({allowed_operators}): logical operator
|
||||||
- `statement1`, `statement2`, ... (comparison statements or logical operation \
|
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
|
||||||
statements): one or more statements to apply the operation to
|
|
||||||
|
|
||||||
Make sure that you only use the comparators and logical operators listed above and \
|
Make sure that you only use the comparators and logical operators listed above and no others.
|
||||||
no others.
|
|
||||||
Make sure that filters only refer to attributes that exist in the data source.
|
Make sure that filters only refer to attributes that exist in the data source.
|
||||||
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
||||||
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
||||||
Make sure that filters take into account the descriptions of attributes and only make \
|
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
|
||||||
comparisons that are feasible given the type of data being stored.
|
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\
|
||||||
Make sure that filters are only used as needed. If there are no filters that should be \
|
|
||||||
applied return "NO_FILTER" for the filter value.\
|
|
||||||
"""
|
"""
|
||||||
|
DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA)
|
||||||
|
|
||||||
SCHEMA_WITH_LIMIT = """\
|
SCHEMA_WITH_LIMIT = """\
|
||||||
<< Structured Request Schema >>
|
<< Structured Request Schema >>
|
||||||
When responding use a markdown code snippet with a JSON object formatted in the \
|
When responding use a markdown code snippet with a JSON object formatted in the following schema:
|
||||||
following schema:
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{{{{
|
{{{{
|
||||||
@ -162,11 +159,9 @@ following schema:
|
|||||||
}}}}
|
}}}}
|
||||||
```
|
```
|
||||||
|
|
||||||
The query string should contain only text that is expected to match the contents of \
|
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||||
documents. Any conditions in the filter should not be mentioned in the query as well.
|
|
||||||
|
|
||||||
A logical condition statement is composed of one or more comparison and logical \
|
A logical condition statement is composed of one or more comparison and logical operation statements.
|
||||||
operation statements.
|
|
||||||
|
|
||||||
A comparison statement takes the form: `comp(attr, val)`:
|
A comparison statement takes the form: `comp(attr, val)`:
|
||||||
- `comp` ({allowed_comparators}): comparator
|
- `comp` ({allowed_comparators}): comparator
|
||||||
@ -175,20 +170,17 @@ A comparison statement takes the form: `comp(attr, val)`:
|
|||||||
|
|
||||||
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
||||||
- `op` ({allowed_operators}): logical operator
|
- `op` ({allowed_operators}): logical operator
|
||||||
- `statement1`, `statement2`, ... (comparison statements or logical operation \
|
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
|
||||||
statements): one or more statements to apply the operation to
|
|
||||||
|
|
||||||
Make sure that you only use the comparators and logical operators listed above and \
|
Make sure that you only use the comparators and logical operators listed above and no others.
|
||||||
no others.
|
|
||||||
Make sure that filters only refer to attributes that exist in the data source.
|
Make sure that filters only refer to attributes that exist in the data source.
|
||||||
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
||||||
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
||||||
Make sure that filters take into account the descriptions of attributes and only make \
|
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
|
||||||
comparisons that are feasible given the type of data being stored.
|
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
|
||||||
Make sure that filters are only used as needed. If there are no filters that should be \
|
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.
|
||||||
applied return "NO_FILTER" for the filter value.
|
|
||||||
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense.
|
|
||||||
"""
|
"""
|
||||||
|
SCHEMA_WITH_LIMIT_PROMPT = PromptTemplate.from_template(SCHEMA_WITH_LIMIT)
|
||||||
|
|
||||||
DEFAULT_PREFIX = """\
|
DEFAULT_PREFIX = """\
|
||||||
Your goal is to structure the user's query to match the request schema provided below.
|
Your goal is to structure the user's query to match the request schema provided below.
|
||||||
@ -196,6 +188,20 @@ Your goal is to structure the user's query to match the request schema provided
|
|||||||
{schema}\
|
{schema}\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
PREFIX_WITH_DATA_SOURCE = (
|
||||||
|
DEFAULT_PREFIX
|
||||||
|
+ """
|
||||||
|
|
||||||
|
<< Data Source >>
|
||||||
|
```json
|
||||||
|
{{{{
|
||||||
|
"content": "{content}",
|
||||||
|
"attributes": {attributes}
|
||||||
|
}}}}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_SUFFIX = """\
|
DEFAULT_SUFFIX = """\
|
||||||
<< Example {i}. >>
|
<< Example {i}. >>
|
||||||
Data Source:
|
Data Source:
|
||||||
@ -211,3 +217,11 @@ User Query:
|
|||||||
|
|
||||||
Structured Request:
|
Structured Request:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SUFFIX_WITHOUT_DATA_SOURCE = """\
|
||||||
|
<< Example {i}. >>
|
||||||
|
User Query:
|
||||||
|
{{query}}
|
||||||
|
|
||||||
|
Structured Request:
|
||||||
|
"""
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
"""Retriever that generates and executes structured queries over its own data source."""
|
"""Retriever that generates and executes structured queries over its own data source."""
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
)
|
)
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains.query_constructor.base import load_query_constructor_runnable
|
||||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
|
||||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
@ -27,6 +26,7 @@ from langchain.retrievers.self_query.vectara import VectaraTranslator
|
|||||||
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
|
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.schema.runnable import Runnable
|
||||||
from langchain.vectorstores import (
|
from langchain.vectorstores import (
|
||||||
Chroma,
|
Chroma,
|
||||||
DashVector,
|
DashVector,
|
||||||
@ -86,8 +86,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
|
|
||||||
vectorstore: VectorStore
|
vectorstore: VectorStore
|
||||||
"""The underlying vector store from which documents will be retrieved."""
|
"""The underlying vector store from which documents will be retrieved."""
|
||||||
llm_chain: LLMChain
|
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
|
||||||
"""The LLMChain for generating the vector store queries."""
|
"""The query constructor chain for generating the vector store queries.
|
||||||
|
|
||||||
|
llm_chain is legacy name kept for backwards compatibility."""
|
||||||
search_type: str = "similarity"
|
search_type: str = "similarity"
|
||||||
"""The search type to perform on the vector store."""
|
"""The search type to perform on the vector store."""
|
||||||
search_kwargs: dict = Field(default_factory=dict)
|
search_kwargs: dict = Field(default_factory=dict)
|
||||||
@ -103,6 +105,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_translator(cls, values: Dict) -> Dict:
|
def validate_translator(cls, values: Dict) -> Dict:
|
||||||
@ -113,23 +116,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_structured_query(
|
@property
|
||||||
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
|
def llm_chain(self) -> Runnable:
|
||||||
) -> StructuredQuery:
|
"""llm_chain is legacy name kept for backwards compatibility."""
|
||||||
structured_query = cast(
|
return self.query_constructor
|
||||||
StructuredQuery,
|
|
||||||
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
|
|
||||||
)
|
|
||||||
return structured_query
|
|
||||||
|
|
||||||
async def _aget_structured_query(
|
|
||||||
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
|
|
||||||
) -> StructuredQuery:
|
|
||||||
structured_query = cast(
|
|
||||||
StructuredQuery,
|
|
||||||
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
|
|
||||||
)
|
|
||||||
return structured_query
|
|
||||||
|
|
||||||
def _prepare_query(
|
def _prepare_query(
|
||||||
self, query: str, structured_query: StructuredQuery
|
self, query: str, structured_query: StructuredQuery
|
||||||
@ -167,8 +157,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
structured_query = self.query_constructor.invoke(
|
||||||
structured_query = self._get_structured_query(inputs, run_manager)
|
{"query": query}, config={"callbacks": run_manager.get_child()}
|
||||||
|
)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated Query: {structured_query}")
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
@ -186,8 +177,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
structured_query = await self.query_constructor.ainvoke(
|
||||||
structured_query = await self._aget_structured_query(inputs, run_manager)
|
{"query": query}, config={"callbacks": run_manager.get_child()}
|
||||||
|
)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated Query: {structured_query}")
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
@ -200,7 +192,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
vectorstore: VectorStore,
|
vectorstore: VectorStore,
|
||||||
document_contents: str,
|
document_contents: str,
|
||||||
metadata_field_info: List[AttributeInfo],
|
metadata_field_info: Sequence[Union[AttributeInfo, dict]],
|
||||||
structured_query_translator: Optional[Visitor] = None,
|
structured_query_translator: Optional[Visitor] = None,
|
||||||
chain_kwargs: Optional[Dict] = None,
|
chain_kwargs: Optional[Dict] = None,
|
||||||
enable_limit: bool = False,
|
enable_limit: bool = False,
|
||||||
@ -219,7 +211,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
chain_kwargs[
|
chain_kwargs[
|
||||||
"allowed_operators"
|
"allowed_operators"
|
||||||
] = structured_query_translator.allowed_operators
|
] = structured_query_translator.allowed_operators
|
||||||
llm_chain = load_query_constructor_chain(
|
query_constructor = load_query_constructor_runnable(
|
||||||
llm,
|
llm,
|
||||||
document_contents,
|
document_contents,
|
||||||
metadata_field_info,
|
metadata_field_info,
|
||||||
@ -227,7 +219,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
**chain_kwargs,
|
**chain_kwargs,
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
llm_chain=llm_chain,
|
query_constructor=query_constructor,
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
use_original_query=use_original_query,
|
use_original_query=use_original_query,
|
||||||
structured_query_translator=structured_query_translator,
|
structured_query_translator=structured_query_translator,
|
||||||
|
Loading…
Reference in New Issue
Block a user