sql: do not hard code the LIMIT clause in the table_info section (#1563)

Seeing a lot of issues in Discord in which the LLM is not using the
correct LIMIT clause for different SQL dialects. ie, it's using `LIMIT`
for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc.
I think this could be due to us specifying the LIMIT statement in the
example rows portion of `table_info`. So the LLM is seeing the `LIMIT`
statement used in the prompt.
Since we can't specify each dialect's method here, I think it's fine to
just replace the `SELECT... LIMIT 3;` statement with `3 rows from
table_name table:`, and wrap everything in a block comment directly
following the `CREATE` statement. The Rajkumar et al paper wrapped the
example rows and `SELECT` statement in a block comment as well anyway.
Thoughts @fpingham?
This commit is contained in:
Jon Luo
2023-03-14 02:08:27 -04:00
committed by GitHub
parent 9ee2713272
commit 0a1b1806e9
4 changed files with 41 additions and 37 deletions

View File

@@ -34,9 +34,10 @@ def test_table_info() -> None:
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3;
/*
3 rows from user table:
user_id user_name
/*
CREATE TABLE company (
@@ -44,9 +45,10 @@ def test_table_info() -> None:
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3;
/*
3 rows from company table:
company_id company_location
*/
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
@@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None:
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2;
/*
2 rows from company table:
company_id company_location
*/
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 2;
/*
2 rows from user table:
user_id user_name
13 Harrison
14 Chase
*/
"""
assert sorted(output.split()) == sorted(expected_output.split())

View File

@@ -54,9 +54,10 @@ def test_table_info() -> None:
user_name VARCHAR NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3;
/*
3 rows from user table:
user_id user_name
*/
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))