mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
more complex sql chain (#619)
add a more complex sql chain that first subsets the necessary tables
This commit is contained in:
parent
49b3d6c78c
commit
1c71fadfdc
@ -179,10 +179,80 @@
|
||||
"db_chain.run(\"How many employees are there in the foobar table?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c12ae15a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## SQLDatabaseSequentialChain\n",
|
||||
"\n",
|
||||
"Chain for querying SQL database that is a sequential chain.\n",
|
||||
"\n",
|
||||
"The chain is as follows:\n",
|
||||
"\n",
|
||||
" 1. Based on the query, determine which tables to use.\n",
|
||||
" 2. Based on those tables, call the normal SQL database chain.\n",
|
||||
"\n",
|
||||
"This is useful in cases where the number of tables in the database is large."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e59a4740",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains.sql_database.base import SQLDatabaseSequentialChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "58bb49b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "95017b1a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n",
|
||||
"Table names to use:\n",
|
||||
"\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' 0 employees are also customers.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(\"How many employees are also customers?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e59a4740",
|
||||
"id": "b2998b03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -47,7 +47,10 @@ class SequentialChain(Chain, BaseModel):
|
||||
for chain in chains:
|
||||
missing_vars = set(chain.input_keys).difference(known_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(f"Missing required input keys: {missing_vars}")
|
||||
raise ValueError(
|
||||
f"Missing required input keys: {missing_vars}, "
|
||||
f"only had {known_variables}"
|
||||
)
|
||||
overlapping_keys = known_variables.intersection(chain.output_keys)
|
||||
if overlapping_keys:
|
||||
raise ValueError(
|
||||
|
@ -1,11 +1,13 @@
|
||||
"""Chain for interacting with SQL Database."""
|
||||
from typing import Dict, List
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
@ -53,15 +55,18 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
|
||||
if self.verbose:
|
||||
self.callback_manager.on_text(input_text)
|
||||
# If not present, then defaults to None which is all tables.
|
||||
table_names_to_use = inputs.get("table_names_to_use")
|
||||
table_info = self.database.get_table_info(table_names=table_names_to_use)
|
||||
llm_inputs = {
|
||||
"input": input_text,
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_info,
|
||||
"table_info": table_info,
|
||||
"stop": ["\nSQLResult:"],
|
||||
}
|
||||
sql_cmd = llm_chain.predict(**llm_inputs)
|
||||
@ -78,3 +83,68 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
if self.verbose:
|
||||
self.callback_manager.on_text(final_result, color="green")
|
||||
return {self.output_key: final_result}
|
||||
|
||||
|
||||
class SQLDatabaseSequentialChain(Chain, BaseModel):
    """Two-step chain for answering a question against a SQL database.

    Each call performs the following steps:
        1. Ask the decider chain which of the database's tables are relevant
           to the query.
        2. Run the regular SQL database chain, restricted to those tables.

    Narrowing the tables first is helpful when the database holds a large
    number of tables.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @classmethod
    def from_llm(
        cls,
        llm: BaseLLM,
        database: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Build the decider chain and the SQL chain from a single LLM.

        ``query_prompt`` is handed to the SQL chain and ``decider_prompt`` to
        the table-selection chain; any extra keyword arguments are forwarded
        to this class's constructor.
        """
        return cls(
            sql_chain=SQLDatabaseChain(
                llm=llm, database=database, prompt=query_prompt
            ),
            decider_chain=LLMChain(
                llm=llm, prompt=decider_prompt, output_key="table_names"
            ),
            **kwargs,
        )

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Pick the relevant tables, then delegate to the SQL chain."""
        # Step 1: show the decider every table name the database exposes.
        candidate_tables = ", ".join(self.sql_chain.database.get_table_names())
        question = inputs[self.input_key]
        table_names_to_use = self.decider_chain.predict_and_parse(
            query=question, table_names=candidate_tables
        )
        if self.verbose:
            self.callback_manager.on_text("Table names to use:", end="\n")
            self.callback_manager.on_text(str(table_names_to_use), color="yellow")
        # Step 2: run the SQL chain, passing along the chosen tables.
        return self.sql_chain(
            {
                self.sql_chain.input_key: question,
                "table_names_to_use": table_names_to_use,
            },
            return_only_outputs=True,
        )
|
||||
|
@ -1,4 +1,5 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.base import CommaSeparatedListOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
@ -17,3 +18,16 @@ Question: {input}"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
|
||||
)
|
||||
|
||||
# Prompt used to pick the subset of tables relevant to a question. The LLM is
# expected to answer with a comma separated list, which the attached output
# parser turns into a list of table names.
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.

Question: {query}

Table Names: {table_names}

Relevant Table Names:"""
DECIDER_PROMPT = PromptTemplate(
    input_variables=["query", "table_names"],
    template=_DECIDER_TEMPLATE,
    output_parser=CommaSeparatedListOutputParser(),
)
|
||||
|
@ -64,6 +64,14 @@ class ListOutputParser(BaseOutputParser):
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
    """Parse out comma separated lists."""

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call into a list of items.

        Args:
            text: Raw LLM output, expected to be items separated by commas.

        Returns:
            The items in their original order, each stripped of surrounding
            whitespace. Empty items (from a trailing comma, doubled commas or
            empty output) are dropped.
        """
        # Split on the bare comma and strip each piece: LLM output is not
        # guaranteed to put exactly one space after every comma, so "a,b",
        # "a, b" and "a ,  b" must all yield the same result.
        pieces = (piece.strip() for piece in text.split(","))
        return [piece for piece in pieces if piece]
|
||||
|
||||
|
||||
class RegexParser(BaseOutputParser, BaseModel):
|
||||
"""Class to parse the output into a dictionary."""
|
||||
|
||||
|
@ -50,7 +50,8 @@ class SQLDatabase:
|
||||
"""Return string representation of dialect to use."""
|
||||
return self._engine.dialect.name
|
||||
|
||||
def _get_table_names(self) -> Iterable[str]:
|
||||
def get_table_names(self) -> Iterable[str]:
    """Return the names of the tables this database object exposes.

    When an explicit include list is set it is returned as-is; otherwise
    the result is every known table minus the ignored ones.
    """
    included = self._include_tables
    if not included:
        return set(self._all_tables).difference(self._ignore_tables)
    return included
|
||||
@ -58,9 +59,19 @@ class SQLDatabase:
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Information about all tables in the database."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables."""
|
||||
all_table_names = self.get_table_names()
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
template = "Table '{table_name}' has columns: {columns}."
|
||||
tables = []
|
||||
for table_name in self._get_table_names():
|
||||
for table_name in all_table_names:
|
||||
columns = []
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||
|
Loading…
Reference in New Issue
Block a user