community: Add FewShotSQLTool (#28232)

The `FewShotSQLTool` retrieves example SQL queries from a
`BaseExampleSelector` for a given question.
This is useful for giving an SQL agent [few-shot
examples](https://python.langchain.com/docs/how_to/sql_prompting/#few-shot-examples)
to draw on when it writes a query.

Example usage:
```python
# NOTE: this snippet is abbreviated. `OpenAIEmbeddings`,
# `SemanticSimilarityExampleSelector`, `AstraDB`, `FewShotSQLTool`,
# `create_sql_agent`, `examples`, `llm`, `db` and the ASTRA_DB_* values are
# not defined here and must be imported/created by the caller.
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX

embeddings = OpenAIEmbeddings()

# Build an example selector over `examples`, using AstraDB as the vector store
# class and the "input" key of each example as the text to embed.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    AstraDB,
    k=5,
    input_keys=["input"],
    collection_name="lc_few_shots",
    token=ASTRA_DB_APPLICATION_TOKEN,
    api_endpoint=ASTRA_DB_API_ENDPOINT,
)

# The description is overridden so the agent is told when to call the tool.
few_shot_sql_tool = FewShotSQLTool(
    example_selector=example_selector,
    description="Input to this tool is the input question, output is a few SQL query examples related to the input question. Always use this tool before checking the query with sql_db_query_checker!"
)

# Register the tool as an extra tool and extend the prefix so the agent is
# instructed to fetch example queries first.
agent = create_sql_agent(
    llm=llm, 
    db=db, 
    prefix=SQL_PREFIX + "\nYou MUST get some example queries before creating the query.", 
    extra_tools=[few_shot_sql_tool]
)

result = agent.invoke({"input": "How many artists are there?"})
```

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Christophe Bornet 2024-12-16 16:37:21 +01:00 committed by GitHub
parent 8d746086ab
commit 6ddd5dbb1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 112 additions and 0 deletions

View File

@ -0,0 +1,3 @@
# Re-export FewShotSQLTool from the `tool` submodule at package level.
from langchain_community.tools.few_shot.tool import FewShotSQLTool

# Public names of this package.
__all__ = ["FewShotSQLTool"]

View File

@ -0,0 +1,46 @@
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
class _FewShotToolInput(BaseModel):
    # Argument schema used by FewShotSQLTool: one required string, the
    # question to find example SQL queries for.
    # (Written as a comment rather than a docstring so the schema pydantic
    # generates for this model stays exactly as before.)
    question: str = Field(
        default=...,
        description="The question for which we want example SQL queries.",
    )
class FewShotSQLTool(BaseTool):  # type: ignore[override]
    """Tool to get example SQL queries related to an input question."""

    name: str = "few_shot_sql"
    description: str = "Tool to get example SQL queries related to an input question."
    args_schema: Type[BaseModel] = _FewShotToolInput

    # Selector asked for examples matching the question. It is excluded from
    # the model's serialized output (``exclude=True``).
    example_selector: BaseExampleSelector = Field(exclude=True)
    # Key of each example dict that holds the natural-language question. The
    # incoming question is also passed to the selector under this key.
    example_input_key: str = "input"
    # Key of each example dict that holds the SQL query.
    example_query_key: str = "query"

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _run(
        self,
        question: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return example SQL queries related to ``question``.

        Args:
            question: The user question to find example queries for.
            run_manager: Optional callback manager for the tool run (unused).

        Returns:
            The selected examples, each rendered as
            ``User input: <input>`` / ``SQL query: <query>``.
        """
        # The triple braces matter: ``{{`` / ``}}`` emit literal braces around
        # the interpolated key name, so the template contains placeholders such
        # as ``{input}`` and ``{query}``. With single braces the key names
        # themselves were baked into the text and every example rendered as the
        # same constant string.
        example_prompt = PromptTemplate.from_template(
            f"User input: {{{self.example_input_key}}}\n"
            f"SQL query: {{{self.example_query_key}}}"
        )
        prompt = FewShotPromptTemplate(
            example_prompt=example_prompt,
            example_selector=self.example_selector,
            suffix="",  # only the examples are wanted, nothing after them
            input_variables=[self.example_input_key],
        )
        return prompt.format(**{self.example_input_key: question})

View File

@ -0,0 +1,63 @@
from typing import Type
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_tests.integration_tests import ToolsIntegrationTests
from langchain_tests.unit_tests import ToolsUnitTests
from langchain_community.tools.few_shot.tool import FewShotSQLTool
# Few-shot fixtures. Each example maps a question ("input") to its SQL
# ("query"); these keys match FewShotSQLTool's default `example_input_key`
# and `example_query_key`.
EXAMPLES = [
    {
        "input": "Number of rows in artist table",
        "query": "select count(*) from Artist",
    },
    {
        "input": "Number of rows in album table",
        "query": "select count(*) from Album",
    },
]

# Fake embeddings and an in-memory vector store are used for the selector.
EMBEDDINGS = DeterministicFakeEmbedding(size=10)

EXAMPLE_SELECTOR = SemanticSimilarityExampleSelector.from_examples(
    EXAMPLES,
    EMBEDDINGS,
    InMemoryVectorStore,
    k=5,
    input_keys=["input"],
)
class TestFewShotSQLToolUnit(ToolsUnitTests):
    """Run the shared ``ToolsUnitTests`` suite against ``FewShotSQLTool``."""

    @property
    def tool_constructor(self) -> Type[FewShotSQLTool]:
        """Tool class under test."""
        return FewShotSQLTool

    @property
    def tool_constructor_params(self) -> dict:
        """Keyword arguments for constructing the tool."""
        params = dict(
            example_selector=EXAMPLE_SELECTOR,
            description="Use this tool to select examples.",
        )
        return params

    @property
    def tool_invoke_params_example(self) -> dict:
        """Sample arguments for invoking the tool."""
        sample_question = "How many rows are in the customer table?"
        return {"question": sample_question}
class TestFewShotSQLToolIntegration(ToolsIntegrationTests):
    """Run the shared ``ToolsIntegrationTests`` suite against ``FewShotSQLTool``."""

    @property
    def tool_constructor(self) -> Type[FewShotSQLTool]:
        """Tool class under test."""
        return FewShotSQLTool

    @property
    def tool_constructor_params(self) -> dict:
        """Keyword arguments for constructing the tool."""
        params = dict(
            example_selector=EXAMPLE_SELECTOR,
            description="Use this tool to select examples.",
        )
        return params

    @property
    def tool_invoke_params_example(self) -> dict:
        """Sample arguments for invoking the tool."""
        sample_question = "How many rows are in the customer table?"
        return {"question": sample_question}