notebook fmt (#12498)

This commit is contained in:
Bagatur
2023-10-29 15:50:09 -07:00
committed by GitHub
parent 56cc5b847c
commit 2424fff3f1
342 changed files with 8261 additions and 6796 deletions

View File

@@ -1,3 +1,3 @@
from sql_llama2.chain import chain
__all__ = ["chain"]
__all__ = ["chain"]

View File

@@ -20,6 +20,7 @@ rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
def get_schema(_):
return db.get_table_info()
@@ -27,6 +28,7 @@ def get_schema(_):
def run_query(query):
return db.run(query)
template_query = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
@@ -64,10 +66,12 @@ prompt_response = ChatPromptTemplate.from_messages(
]
)
# Supply the input types to the prompt
# Supply the input types to the prompt
class InputType(BaseModel):
question: str
chain = (
RunnablePassthrough.assign(query=sql_response).with_types(input_type=InputType)
| RunnablePassthrough.assign(