From 6be5f4e4c4ae6c9f9f80a9c488b0f5284ca53807 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 17 Jan 2023 22:32:28 -0800 Subject: [PATCH] Harrison/sql db chain (#641) Co-authored-by: Bruno Bornsztein --- docs/modules/chains/examples/sqlite.ipynb | 14 +++++++++++--- langchain/chains/__init__.py | 6 +++++- langchain/chains/sql_database/base.py | 4 +++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index 9b0b8a7b234..c28103f8cc8 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -266,7 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.chains.sql_database.base import SQLDatabaseSequentialChain" + "from langchain.chains import SQLDatabaseSequentialChain" ] }, { @@ -293,14 +293,22 @@ "\n", "\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n", "Table names to use:\n", - "\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n", + "\u001b[33;1m\u001b[1;3m['Customer', 'Employee']\u001b[0m\n", + "\n", + "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", + "How many employees are also customers? \n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Customer c INNER JOIN Employee e ON c.SupportRepId = e.EmployeeId;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(59,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m There are 59 employees who are also customers.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' 0 employees are also customers.'" + "' There are 59 employees who are also customers.'" ] }, "execution_count": 5, diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 60fdf184183..6879e4850b7 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -12,7 +12,10 @@ from langchain.chains.pal.base import PALChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain -from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.chains.sql_database.base import ( + SQLDatabaseChain, + SQLDatabaseSequentialChain, +) from langchain.chains.transform import TransformChain from langchain.chains.vector_db_qa.base import VectorDBQA @@ -35,4 +38,5 @@ __all__ = [ "TransformChain", "MapReduceChain", "OpenAIModerationChain", + "SQLDatabaseSequentialChain", ] diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 30309ed1c02..41ef484b4d0 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -109,7 +109,9 @@ class SQLDatabaseSequentialChain(Chain, BaseModel): **kwargs: Any, ) -> SQLDatabaseSequentialChain: """Load the necessary chains.""" - sql_chain = SQLDatabaseChain(llm=llm, database=database, prompt=query_prompt) + sql_chain = SQLDatabaseChain( + llm=llm, database=database, prompt=query_prompt, **kwargs + ) decider_chain = LLMChain( llm=llm, prompt=decider_prompt, output_key="table_names" )