mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
sql: do not hard code the LIMIT clause in the table_info section (#1563)
Seeing a lot of issues in Discord in which the LLM is not using the correct LIMIT clause for different SQL dialects. ie, it's using `LIMIT` for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc. I think this could be due to us specifying the LIMIT statement in the example rows portion of `table_info`. So the LLM is seeing the `LIMIT` statement used in the prompt. Since we can't specify each dialect's method here, I think it's fine to just replace the `SELECT... LIMIT 3;` statement with `3 rows from table_name table:`, and wrap everything in a block comment directly following the `CREATE` statement. The Rajkumar et al paper wrapped the example rows and `SELECT` statement in a block comment as well anyway. Thoughts @fpingham?
This commit is contained in:
parent
9ee2713272
commit
0a1b1806e9
@ -377,18 +377,19 @@
|
||||
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
|
||||
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"SELECT * FROM 'Track' LIMIT 2;\n",
|
||||
"/*\n",
|
||||
"2 rows from Track table:\n",
|
||||
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
|
||||
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
|
||||
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n"
|
||||
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
|
||||
"*/\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/jon/projects/langchain/langchain/sql_database.py:121: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
|
||||
"/home/jon/projects/langchain/langchain/sql_database.py:135: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
|
||||
" sample_rows = connection.execute(command)\n"
|
||||
]
|
||||
}
|
||||
@ -467,12 +468,13 @@
|
||||
"\t\"Composer\" NVARCHAR(220),\n",
|
||||
"\tPRIMARY KEY (\"TrackId\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"SELECT * FROM 'Track' LIMIT 3;\n",
|
||||
"/*\n",
|
||||
"3 rows from Track table:\n",
|
||||
"TrackId\tName\tComposer\n",
|
||||
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
|
||||
"2\tBalls to the Wall\tNone\n",
|
||||
"3\tMy favorite song ever\tThe coolest composer of all time\"\"\"\n",
|
||||
"3\tMy favorite song ever\tThe coolest composer of all time\n",
|
||||
"*/\"\"\"\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@ -492,11 +494,12 @@
|
||||
"\t\"Name\" NVARCHAR(120), \n",
|
||||
"\tPRIMARY KEY (\"PlaylistId\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"SELECT * FROM 'Playlist' LIMIT 2;\n",
|
||||
"/*\n",
|
||||
"2 rows from Playlist table:\n",
|
||||
"PlaylistId\tName\n",
|
||||
"1\tMusic\n",
|
||||
"2\tMovies\n",
|
||||
"*/\n",
|
||||
"\n",
|
||||
"CREATE TABLE Track (\n",
|
||||
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
||||
@ -504,12 +507,13 @@
|
||||
"\t\"Composer\" NVARCHAR(220),\n",
|
||||
"\tPRIMARY KEY (\"TrackId\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"SELECT * FROM 'Track' LIMIT 3;\n",
|
||||
"/*\n",
|
||||
"3 rows from Track table:\n",
|
||||
"TrackId\tName\tComposer\n",
|
||||
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
|
||||
"2\tBalls to the Wall\tNone\n",
|
||||
"3\tMy favorite song ever\tThe coolest composer of all time\n"
|
||||
"3\tMy favorite song ever\tThe coolest composer of all time\n",
|
||||
"*/\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -126,12 +126,6 @@ class SQLDatabase:
|
||||
# build the select command
|
||||
command = select(table).limit(self._sample_rows_in_table_info)
|
||||
|
||||
# save the command in string format
|
||||
select_star = (
|
||||
f"SELECT * FROM '{table.name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
|
||||
# save the columns in string format
|
||||
columns_str = "\t".join([col.name for col in table.columns])
|
||||
|
||||
@ -152,16 +146,18 @@ class SQLDatabase:
|
||||
except ProgrammingError:
|
||||
sample_rows_str = ""
|
||||
|
||||
# build final info for table
|
||||
tables.append(
|
||||
create_table
|
||||
+ select_star
|
||||
+ ";\n"
|
||||
+ columns_str
|
||||
+ "\n"
|
||||
+ sample_rows_str
|
||||
table_info = (
|
||||
f"{create_table.rstrip()}\n"
|
||||
f"/*\n"
|
||||
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
|
||||
f"{columns_str}\n"
|
||||
f"{sample_rows_str}\n"
|
||||
f"*/"
|
||||
)
|
||||
|
||||
# build final info for table
|
||||
tables.append(table_info)
|
||||
|
||||
else:
|
||||
tables.append(create_table)
|
||||
|
||||
|
@ -34,9 +34,10 @@ def test_table_info() -> None:
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
/*
|
||||
3 rows from user table:
|
||||
user_id user_name
|
||||
/*
|
||||
|
||||
|
||||
CREATE TABLE company (
|
||||
@ -44,9 +45,10 @@ def test_table_info() -> None:
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 3;
|
||||
/*
|
||||
3 rows from company table:
|
||||
company_id company_location
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None:
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 2;
|
||||
/*
|
||||
2 rows from company table:
|
||||
company_id company_location
|
||||
|
||||
*/
|
||||
|
||||
CREATE TABLE user (
|
||||
user_id INTEGER NOT NULL,
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 2;
|
||||
/*
|
||||
2 rows from user table:
|
||||
user_id user_name
|
||||
13 Harrison
|
||||
14 Chase
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(output.split()) == sorted(expected_output.split())
|
||||
|
@ -54,9 +54,10 @@ def test_table_info() -> None:
|
||||
user_name VARCHAR NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
/*
|
||||
3 rows from user table:
|
||||
user_id user_name
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
|
Loading…
Reference in New Issue
Block a user