diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index c28103f8cc8..7a6c8180c89 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "a8fc8f23", "metadata": {}, "outputs": [], @@ -242,6 +242,74 @@ "db_chain.run(\"What are some example tracks by composer Johann Sebastian Bach?\")" ] }, + { + "cell_type": "markdown", + "id": "bcc5e936", + "metadata": {}, + "source": [ + "## Adding first row of each table\n", + "Sometimes, the format of the data is not obvious and it is optimal to include the first row of the table in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "9a22ee47", + "metadata": {}, + "outputs": [], + "source": [ + "db = SQLDatabase.from_uri(\n", + " \"sqlite:///../../../../notebooks/Chinook.db\", \n", + " include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n", + " sample_row_in_table_info=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bcb7a489", + "metadata": {}, + "outputs": [], + "source": [ + "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "81e05d82", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", + "What are some example tracks by Bach? \n", + "SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n", + "\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_chain.run(\"What are some example tracks by Bach?\")" + ] + }, { "cell_type": "markdown", "id": "c12ae15a", @@ -319,14 +387,6 @@ "source": [ "chain.run(\"How many employees are also customers?\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2998b03", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -345,7 +405,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 2d5a8405936..27fc98a668e 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -16,6 +16,7 @@ class SQLDatabase: schema: Optional[str] = None, ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, + sample_row_in_table_info: bool = False, ): """Create engine from database URI.""" self._engine = engine @@ -39,6 +40,7 @@ class SQLDatabase: raise ValueError( f"ignore_tables {missing_tables} not found in database" ) + self._sample_row_in_table_info = sample_row_in_table_info @classmethod def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase: @@ -69,14 +71,28 @@ class SQLDatabase: if missing_tables: raise ValueError(f"table_names {missing_tables} not found in database") all_table_names = table_names + template = "Table '{table_name}' has columns: {columns}." + tables = [] for table_name in all_table_names: + columns = [] for column in self._inspector.get_columns(table_name, schema=self._schema): columns.append(f"{column['name']} ({str(column['type'])})") column_str = ", ".join(columns) table_str = template.format(table_name=table_name, columns=column_str) + + if self._sample_row_in_table_info: + row_template = ( + " Here is an example row for this table" + " (long strings are truncated): {sample_row}." + ) + sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1") + if len(eval(sample_row)) > 0: + sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]]) + table_str += row_template.format(sample_row=sample_row) + tables.append(table_str) return "\n".join(tables) diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index 4c5d208f95d..d735d0555e8 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -35,6 +35,27 @@ def test_table_info() -> None: assert sorted(output.split("\n")) == sorted(expected_output) +def test_table_info_w_sample_row() -> None: + """Test that table info is constructed properly.""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + stmt = insert(user).values(user_id=13, user_name="Harrison") + with engine.begin() as conn: + conn.execute(stmt) + + db = SQLDatabase(engine, sample_row_in_table_info=True) + + output = db.table_info + expected_output = ( + "Table 'company' has columns: company_id (INTEGER), " + "company_location (VARCHAR).\n" + "Table 'user' has columns: user_id (INTEGER), " + "user_name (VARCHAR(16)). Here is an example row " + "for this table (long strings are truncated): 13 Harrison." + ) + assert sorted(output.split("\n")) == sorted(expected_output.split("\n")) + + def test_sql_database_run() -> None: """Test that commands can be run successfully and returned in correct format.""" engine = create_engine("sqlite:///:memory:")