mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
sql: do not hard code the LIMIT clause in the table_info section (#1563)
Seeing a lot of issues in Discord in which the LLM is not using the correct LIMIT clause for different SQL dialects. ie, it's using `LIMIT` for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc. I think this could be due to us specifying the LIMIT statement in the example rows portion of `table_info`. So the LLM is seeing the `LIMIT` statement used in the prompt. Since we can't specify each dialect's method here, I think it's fine to just replace the `SELECT... LIMIT 3;` statement with `3 rows from table_name table:`, and wrap everything in a block comment directly following the `CREATE` statement. The Rajkumar et al paper wrapped the example rows and `SELECT` statement in a block comment as well anyway. Thoughts @fpingham?
This commit is contained in:
@@ -34,9 +34,10 @@ def test_table_info() -> None:
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
/*
|
||||
3 rows from user table:
|
||||
user_id user_name
|
||||
/*
|
||||
|
||||
|
||||
CREATE TABLE company (
|
||||
@@ -44,9 +45,10 @@ def test_table_info() -> None:
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 3;
|
||||
/*
|
||||
3 rows from company table:
|
||||
company_id company_location
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
@@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None:
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 2;
|
||||
/*
|
||||
2 rows from company table:
|
||||
company_id company_location
|
||||
|
||||
*/
|
||||
|
||||
CREATE TABLE user (
|
||||
user_id INTEGER NOT NULL,
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 2;
|
||||
/*
|
||||
2 rows from user table:
|
||||
user_id user_name
|
||||
13 Harrison
|
||||
14 Chase
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(output.split()) == sorted(expected_output.split())
|
||||
|
@@ -54,9 +54,10 @@ def test_table_info() -> None:
|
||||
user_name VARCHAR NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
/*
|
||||
3 rows from user table:
|
||||
user_id user_name
|
||||
*/
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
|
Reference in New Issue
Block a user