mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 08:56:27 +00:00
QoL improvements to query constructor (#11504)
Update the query constructor and self-query retriever to: make it easier to pass in examples; validate the attributes used in a query; remove invalid parts of a query; make it easier to get and edit the prompt; make the query constructor a runnable; and make the self-query retriever use the query constructor as a runnable.
This commit is contained in:
parent
eec53fa294
commit
e7a0def1bc
@ -2,11 +2,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
FilterDirective,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
@ -14,17 +17,21 @@ from langchain.chains.query_constructor.parser import get_parser
|
||||
from langchain.chains.query_constructor.prompt import (
|
||||
DEFAULT_EXAMPLES,
|
||||
DEFAULT_PREFIX,
|
||||
DEFAULT_SCHEMA,
|
||||
DEFAULT_SCHEMA_PROMPT,
|
||||
DEFAULT_SUFFIX,
|
||||
EXAMPLE_PROMPT,
|
||||
EXAMPLES_WITH_LIMIT,
|
||||
SCHEMA_WITH_LIMIT,
|
||||
PREFIX_WITH_DATA_SOURCE,
|
||||
SCHEMA_WITH_LIMIT_PROMPT,
|
||||
SUFFIX_WITHOUT_DATA_SOURCE,
|
||||
USER_SPECIFIED_EXAMPLE_PROMPT,
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable import Runnable
|
||||
|
||||
|
||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
@ -59,6 +66,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
cls,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
allowed_attributes: Optional[Sequence[str]] = None,
|
||||
fix_invalid: bool = False,
|
||||
) -> StructuredQueryOutputParser:
|
||||
"""
|
||||
Create a structured query output parser from components.
|
||||
@ -70,13 +79,73 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
Returns:
|
||||
a structured query output parser
|
||||
"""
|
||||
ast_parser = get_parser(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return cls(ast_parse=ast_parser.parse)
|
||||
ast_parse: Callable
|
||||
if fix_invalid:
|
||||
|
||||
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
|
||||
filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
|
||||
fixed = fix_filter_directive(
|
||||
filter,
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
allowed_attributes=allowed_attributes,
|
||||
)
|
||||
return fixed
|
||||
|
||||
else:
|
||||
ast_parse = get_parser(
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
allowed_attributes=allowed_attributes,
|
||||
).parse
|
||||
return cls(ast_parse=ast_parse)
|
||||
|
||||
|
||||
def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
|
||||
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Remove the parts of a filter directive that use disallowed components.

    Args:
        filter: Filter directive to fix. May be None.
        allowed_comparators: Comparators a Comparison may use. If empty or
            None, comparators are not checked.
        allowed_operators: Operators an Operation may use. If empty or None,
            operators are not checked.
        allowed_attributes: Attribute names a Comparison may refer to. If
            empty or None, attributes are not checked.

    Returns:
        The fixed filter directive, or None if nothing valid remains. The
        input is returned unchanged when no restrictions are given, when it
        is falsy, or when it is neither a Comparison nor an Operation.
    """
    nothing_to_check = not (
        allowed_comparators or allowed_operators or allowed_attributes
    )
    if nothing_to_check or not filter:
        return filter

    if isinstance(filter, Comparison):
        if allowed_comparators and filter.comparator not in allowed_comparators:
            return None
        if allowed_attributes and filter.attribute not in allowed_attributes:
            return None
        return filter

    if isinstance(filter, Operation):
        # A disallowed operator invalidates the whole sub-tree.
        if allowed_operators and filter.operator not in allowed_operators:
            return None
        fixed_args = [
            fix_filter_directive(
                arg,
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            )
            for arg in filter.arguments
        ]
        # Drop the arguments that were entirely invalid.
        args = [arg for arg in fixed_args if arg is not None]
        if not args:
            return None
        # AND/OR over a single remaining statement collapses to that statement.
        if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
            return args[0]
        return Operation(
            operator=filter.operator,
            arguments=args,
        )

    return filter
|
||||
|
||||
|
||||
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
|
||||
info_dicts = {}
|
||||
for i in info:
|
||||
i_dict = dict(i)
|
||||
@ -84,56 +153,90 @@ def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
|
||||
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
|
||||
|
||||
|
||||
def _get_prompt(
|
||||
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Build few-shot example dicts from (user query, structured request) pairs.

    Args:
        input_output_pairs: Sequence of ``(user_query, output)`` tuples, where
            ``output`` is a JSON-serializable dict.

    Returns:
        One dict per pair with keys ``"i"`` (1-based position),
        ``"user_query"`` and ``"structured_request"``. The structured request
        is the output dumped as indented JSON with every brace doubled so it
        survives a later ``str.format``-style templating pass.
    """
    examples = []
    for position, (user_query, output) in enumerate(input_output_pairs, start=1):
        dumped = json.dumps(output, indent=4)
        # Double the braces so they are treated as literal characters by the
        # prompt template that consumes these examples.
        structured_request = dumped.replace("{", "{{").replace("}", "}}")
        examples.append(
            {
                "i": position,
                "user_query": user_query,
                "structured_request": structured_request,
            }
        )
    return examples
|
||||
|
||||
|
||||
def get_query_constructor_prompt(
|
||||
document_contents: str,
|
||||
attribute_info: Sequence[AttributeInfo],
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
attribute_info: Sequence[Union[AttributeInfo, dict]],
|
||||
*,
|
||||
examples: Optional[Sequence] = None,
|
||||
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
|
||||
allowed_operators: Sequence[Operator] = tuple(Operator),
|
||||
enable_limit: bool = False,
|
||||
schema_prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> BasePromptTemplate:
|
||||
"""Create query construction prompt.
|
||||
|
||||
Args:
|
||||
document_contents: The contents of the document to be queried.
|
||||
attribute_info: A list of AttributeInfo objects describing
|
||||
the attributes of the document.
|
||||
examples: Optional list of examples to use for the chain.
|
||||
allowed_comparators: Sequence of allowed comparators.
|
||||
allowed_operators: Sequence of allowed operators.
|
||||
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||
schema_prompt: Prompt for describing query schema. Should have string input
|
||||
variables allowed_comparators and allowed_operators.
|
||||
**kwargs: Additional named params to pass to FewShotPromptTemplate init.
|
||||
"""
|
||||
default_schema_prompt = (
|
||||
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
|
||||
)
|
||||
schema_prompt = schema_prompt or default_schema_prompt
|
||||
attribute_str = _format_attribute_info(attribute_info)
|
||||
allowed_comparators = allowed_comparators or list(Comparator)
|
||||
allowed_operators = allowed_operators or list(Operator)
|
||||
if enable_limit:
|
||||
schema = SCHEMA_WITH_LIMIT.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
schema = schema_prompt.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
)
|
||||
if examples and isinstance(examples[0], tuple):
|
||||
examples = construct_examples(examples)
|
||||
example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
|
||||
prefix = PREFIX_WITH_DATA_SOURCE.format(
|
||||
schema=schema, content=document_contents, attributes=attribute_str
|
||||
)
|
||||
|
||||
examples = examples or EXAMPLES_WITH_LIMIT
|
||||
suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
|
||||
else:
|
||||
schema = DEFAULT_SCHEMA.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
examples = examples or (
|
||||
EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
|
||||
)
|
||||
example_prompt = EXAMPLE_PROMPT
|
||||
prefix = DEFAULT_PREFIX.format(schema=schema)
|
||||
suffix = DEFAULT_SUFFIX.format(
|
||||
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
||||
)
|
||||
|
||||
examples = examples or DEFAULT_EXAMPLES
|
||||
prefix = DEFAULT_PREFIX.format(schema=schema)
|
||||
suffix = DEFAULT_SUFFIX.format(
|
||||
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
||||
)
|
||||
output_parser = StructuredQueryOutputParser.from_components(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return FewShotPromptTemplate(
|
||||
examples=examples,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
examples=list(examples),
|
||||
example_prompt=example_prompt,
|
||||
input_variables=["query"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
output_parser=output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_query_constructor_chain(
|
||||
llm: BaseLanguageModel,
|
||||
document_contents: str,
|
||||
attribute_info: List[AttributeInfo],
|
||||
attribute_info: Sequence[Union[AttributeInfo, dict]],
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
|
||||
allowed_operators: Sequence[Operator] = tuple(Operator),
|
||||
enable_limit: bool = False,
|
||||
schema_prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
"""Load a query constructor chain.
|
||||
@ -141,25 +244,95 @@ def load_query_constructor_chain(
|
||||
Args:
|
||||
llm: BaseLanguageModel to use for the chain.
|
||||
document_contents: The contents of the document to be queried.
|
||||
attribute_info: A list of AttributeInfo objects describing
|
||||
the attributes of the document.
|
||||
attribute_info: Sequence of attributes in the document.
|
||||
examples: Optional list of examples to use for the chain.
|
||||
allowed_comparators: An optional list of allowed comparators.
|
||||
allowed_operators: An optional list of allowed operators.
|
||||
allowed_comparators: Sequence of allowed comparators. Defaults to all
|
||||
Comparators.
|
||||
allowed_operators: Sequence of allowed operators. Defaults to all Operators.
|
||||
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||
**kwargs:
|
||||
schema_prompt: Prompt for describing query schema. Should have string input
|
||||
variables allowed_comparators and allowed_operators.
|
||||
**kwargs: Arbitrary named params to pass to LLMChain.
|
||||
|
||||
Returns:
|
||||
A LLMChain that can be used to construct queries.
|
||||
"""
|
||||
prompt = _get_prompt(
|
||||
prompt = get_query_constructor_prompt(
|
||||
document_contents,
|
||||
attribute_info,
|
||||
examples=examples,
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
enable_limit=enable_limit,
|
||||
schema_prompt=schema_prompt,
|
||||
)
|
||||
return LLMChain(
|
||||
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
|
||||
allowed_attributes = []
|
||||
for ainfo in attribute_info:
|
||||
allowed_attributes.append(
|
||||
ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
|
||||
)
|
||||
output_parser = StructuredQueryOutputParser.from_components(
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
allowed_attributes=allowed_attributes,
|
||||
)
|
||||
# For backwards compatibility.
|
||||
prompt.output_parser = output_parser
|
||||
return LLMChain(llm=llm, prompt=prompt, output_parser=output_parser, **kwargs)
|
||||
|
||||
|
||||
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Load a query constructor runnable chain.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document. Each item is
            either an AttributeInfo or a dict with a "name" key.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string input
            variables allowed_comparators and allowed_operators.
        fix_invalid: Whether to fix invalid filter directives by ignoring invalid
            operators, comparators and attributes.
        **kwargs: Additional named params to pass to FewShotPromptTemplate init.

    Returns:
        A Runnable that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    # Attribute names the parser accepts (or keeps, when fix_invalid is set).
    allowed_attributes = [
        ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
        for ainfo in attribute_info
    ]
    output_parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=allowed_attributes,
        fix_invalid=fix_invalid,
    )
    return prompt | llm | output_parser
|
||||
|
@ -61,11 +61,13 @@ class QueryTransformer(Transformer):
|
||||
*args: Any,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
allowed_attributes: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.allowed_comparators = allowed_comparators
|
||||
self.allowed_operators = allowed_operators
|
||||
self.allowed_attributes = allowed_attributes
|
||||
|
||||
def program(self, *items: Any) -> tuple:
|
||||
return items
|
||||
@ -73,6 +75,11 @@ class QueryTransformer(Transformer):
|
||||
def func_call(self, func_name: Any, args: list) -> FilterDirective:
|
||||
func = self._match_func_name(str(func_name))
|
||||
if isinstance(func, Comparator):
|
||||
if self.allowed_attributes and args[0] not in self.allowed_attributes:
|
||||
raise ValueError(
|
||||
f"Received invalid attributes {args[0]}. Allowed attributes are "
|
||||
f"{self.allowed_attributes}"
|
||||
)
|
||||
return Comparison(comparator=func, attribute=args[0], value=args[1])
|
||||
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
|
||||
return args[0]
|
||||
@ -134,6 +141,7 @@ class QueryTransformer(Transformer):
|
||||
def get_parser(
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
allowed_attributes: Optional[Sequence[str]] = None,
|
||||
) -> Lark:
|
||||
"""
|
||||
Returns a parser for the query language.
|
||||
@ -151,6 +159,8 @@ def get_parser(
|
||||
"Cannot import lark, please install it with 'pip install lark'."
|
||||
)
|
||||
transformer = QueryTransformer(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
allowed_attributes=allowed_attributes,
|
||||
)
|
||||
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")
|
||||
|
@ -3,36 +3,31 @@ from langchain.prompts import PromptTemplate
|
||||
|
||||
SONG_DATA_SOURCE = """\
|
||||
```json
|
||||
{
|
||||
{{
|
||||
"content": "Lyrics of a song",
|
||||
"attributes": {
|
||||
"artist": {
|
||||
"attributes": {{
|
||||
"artist": {{
|
||||
"type": "string",
|
||||
"description": "Name of the song artist"
|
||||
},
|
||||
"length": {
|
||||
}},
|
||||
"length": {{
|
||||
"type": "integer",
|
||||
"description": "Length of the song in seconds"
|
||||
},
|
||||
"genre": {
|
||||
}},
|
||||
"genre": {{
|
||||
"type": "string",
|
||||
"description": "The song genre, one of \"pop\", \"rock\" or \"rap\""
|
||||
}
|
||||
}
|
||||
}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
```\
|
||||
""".replace(
|
||||
"{", "{{"
|
||||
).replace(
|
||||
"}", "}}"
|
||||
)
|
||||
"""
|
||||
|
||||
FULL_ANSWER = """\
|
||||
```json
|
||||
{{
|
||||
"query": "teenager love",
|
||||
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
|
||||
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
||||
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
|
||||
}}
|
||||
```\
|
||||
"""
|
||||
@ -104,16 +99,24 @@ Structured Request:
|
||||
{structured_request}
|
||||
"""
|
||||
|
||||
EXAMPLE_PROMPT = PromptTemplate(
|
||||
input_variables=["i", "data_source", "user_query", "structured_request"],
|
||||
template=EXAMPLE_PROMPT_TEMPLATE,
|
||||
)
|
||||
EXAMPLE_PROMPT = PromptTemplate.from_template(EXAMPLE_PROMPT_TEMPLATE)
|
||||
|
||||
USER_SPECIFIED_EXAMPLE_PROMPT = PromptTemplate.from_template(
|
||||
"""\
|
||||
<< Example {i}. >>
|
||||
User Query:
|
||||
{user_query}
|
||||
|
||||
Structured Request:
|
||||
```json
|
||||
{structured_request}
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
DEFAULT_SCHEMA = """\
|
||||
<< Structured Request Schema >>
|
||||
When responding use a markdown code snippet with a JSON object formatted in the \
|
||||
following schema:
|
||||
When responding use a markdown code snippet with a JSON object formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{{{
|
||||
@ -122,11 +125,9 @@ following schema:
|
||||
}}}}
|
||||
```
|
||||
|
||||
The query string should contain only text that is expected to match the contents of \
|
||||
documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||
|
||||
A logical condition statement is composed of one or more comparison and logical \
|
||||
operation statements.
|
||||
A logical condition statement is composed of one or more comparison and logical operation statements.
|
||||
|
||||
A comparison statement takes the form: `comp(attr, val)`:
|
||||
- `comp` ({allowed_comparators}): comparator
|
||||
@ -135,24 +136,20 @@ A comparison statement takes the form: `comp(attr, val)`:
|
||||
|
||||
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
||||
- `op` ({allowed_operators}): logical operator
|
||||
- `statement1`, `statement2`, ... (comparison statements or logical operation \
|
||||
statements): one or more statements to apply the operation to
|
||||
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
|
||||
|
||||
Make sure that you only use the comparators and logical operators listed above and \
|
||||
no others.
|
||||
Make sure that you only use the comparators and logical operators listed above and no others.
|
||||
Make sure that filters only refer to attributes that exist in the data source.
|
||||
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
||||
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
||||
Make sure that filters take into account the descriptions of attributes and only make \
|
||||
comparisons that are feasible given the type of data being stored.
|
||||
Make sure that filters are only used as needed. If there are no filters that should be \
|
||||
applied return "NO_FILTER" for the filter value.\
|
||||
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
|
||||
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\
|
||||
"""
|
||||
DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA)
|
||||
|
||||
SCHEMA_WITH_LIMIT = """\
|
||||
<< Structured Request Schema >>
|
||||
When responding use a markdown code snippet with a JSON object formatted in the \
|
||||
following schema:
|
||||
When responding use a markdown code snippet with a JSON object formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{{{
|
||||
@ -162,11 +159,9 @@ following schema:
|
||||
}}}}
|
||||
```
|
||||
|
||||
The query string should contain only text that is expected to match the contents of \
|
||||
documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||
|
||||
A logical condition statement is composed of one or more comparison and logical \
|
||||
operation statements.
|
||||
A logical condition statement is composed of one or more comparison and logical operation statements.
|
||||
|
||||
A comparison statement takes the form: `comp(attr, val)`:
|
||||
- `comp` ({allowed_comparators}): comparator
|
||||
@ -175,20 +170,17 @@ A comparison statement takes the form: `comp(attr, val)`:
|
||||
|
||||
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
||||
- `op` ({allowed_operators}): logical operator
|
||||
- `statement1`, `statement2`, ... (comparison statements or logical operation \
|
||||
statements): one or more statements to apply the operation to
|
||||
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
|
||||
|
||||
Make sure that you only use the comparators and logical operators listed above and \
|
||||
no others.
|
||||
Make sure that you only use the comparators and logical operators listed above and no others.
|
||||
Make sure that filters only refer to attributes that exist in the data source.
|
||||
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
|
||||
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
|
||||
Make sure that filters take into account the descriptions of attributes and only make \
|
||||
comparisons that are feasible given the type of data being stored.
|
||||
Make sure that filters are only used as needed. If there are no filters that should be \
|
||||
applied return "NO_FILTER" for the filter value.
|
||||
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense.
|
||||
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
|
||||
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
|
||||
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.
|
||||
"""
|
||||
SCHEMA_WITH_LIMIT_PROMPT = PromptTemplate.from_template(SCHEMA_WITH_LIMIT)
|
||||
|
||||
DEFAULT_PREFIX = """\
|
||||
Your goal is to structure the user's query to match the request schema provided below.
|
||||
@ -196,6 +188,20 @@ Your goal is to structure the user's query to match the request schema provided
|
||||
{schema}\
|
||||
"""
|
||||
|
||||
PREFIX_WITH_DATA_SOURCE = (
|
||||
DEFAULT_PREFIX
|
||||
+ """
|
||||
|
||||
<< Data Source >>
|
||||
```json
|
||||
{{{{
|
||||
"content": "{content}",
|
||||
"attributes": {attributes}
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
DEFAULT_SUFFIX = """\
|
||||
<< Example {i}. >>
|
||||
Data Source:
|
||||
@ -211,3 +217,11 @@ User Query:
|
||||
|
||||
Structured Request:
|
||||
"""
|
||||
|
||||
SUFFIX_WITHOUT_DATA_SOURCE = """\
|
||||
<< Example {i}. >>
|
||||
User Query:
|
||||
{{query}}
|
||||
|
||||
Structured Request:
|
||||
"""
|
||||
|
@ -1,13 +1,12 @@
|
||||
"""Retriever that generates and executes structured queries over its own data source."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
||||
from langchain.chains.query_constructor.base import load_query_constructor_runnable
|
||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
@ -27,6 +26,7 @@ from langchain.retrievers.self_query.vectara import VectaraTranslator
|
||||
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.vectorstores import (
|
||||
Chroma,
|
||||
DashVector,
|
||||
@ -86,8 +86,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
vectorstore: VectorStore
|
||||
"""The underlying vector store from which documents will be retrieved."""
|
||||
llm_chain: LLMChain
|
||||
"""The LLMChain for generating the vector store queries."""
|
||||
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
|
||||
"""The query constructor chain for generating the vector store queries.
|
||||
|
||||
llm_chain is legacy name kept for backwards compatibility."""
|
||||
search_type: str = "similarity"
|
||||
"""The search type to perform on the vector store."""
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
@ -103,6 +105,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_translator(cls, values: Dict) -> Dict:
|
||||
@ -113,23 +116,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_structured_query(
|
||||
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
|
||||
) -> StructuredQuery:
|
||||
structured_query = cast(
|
||||
StructuredQuery,
|
||||
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
|
||||
)
|
||||
return structured_query
|
||||
|
||||
async def _aget_structured_query(
|
||||
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> StructuredQuery:
|
||||
structured_query = cast(
|
||||
StructuredQuery,
|
||||
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
|
||||
)
|
||||
return structured_query
|
||||
@property
|
||||
def llm_chain(self) -> Runnable:
|
||||
"""llm_chain is legacy name kept for backwards compatibility."""
|
||||
return self.query_constructor
|
||||
|
||||
def _prepare_query(
|
||||
self, query: str, structured_query: StructuredQuery
|
||||
@ -167,8 +157,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||
structured_query = self._get_structured_query(inputs, run_manager)
|
||||
structured_query = self.query_constructor.invoke(
|
||||
{"query": query}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
if self.verbose:
|
||||
logger.info(f"Generated Query: {structured_query}")
|
||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||
@ -186,8 +177,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||
structured_query = await self._aget_structured_query(inputs, run_manager)
|
||||
structured_query = await self.query_constructor.ainvoke(
|
||||
{"query": query}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
if self.verbose:
|
||||
logger.info(f"Generated Query: {structured_query}")
|
||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||
@ -200,7 +192,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
llm: BaseLanguageModel,
|
||||
vectorstore: VectorStore,
|
||||
document_contents: str,
|
||||
metadata_field_info: List[AttributeInfo],
|
||||
metadata_field_info: Sequence[Union[AttributeInfo, dict]],
|
||||
structured_query_translator: Optional[Visitor] = None,
|
||||
chain_kwargs: Optional[Dict] = None,
|
||||
enable_limit: bool = False,
|
||||
@ -219,7 +211,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
chain_kwargs[
|
||||
"allowed_operators"
|
||||
] = structured_query_translator.allowed_operators
|
||||
llm_chain = load_query_constructor_chain(
|
||||
query_constructor = load_query_constructor_runnable(
|
||||
llm,
|
||||
document_contents,
|
||||
metadata_field_info,
|
||||
@ -227,7 +219,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
**chain_kwargs,
|
||||
)
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
query_constructor=query_constructor,
|
||||
vectorstore=vectorstore,
|
||||
use_original_query=use_original_query,
|
||||
structured_query_translator=structured_query_translator,
|
||||
|
Loading…
Reference in New Issue
Block a user