mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
community: Add FewShotSQLTool (#28232)
The `FewShotSQLTool` gets some SQL query examples from a `BaseExampleSelector` for a given question. This is useful to provide [few-shot examples](https://python.langchain.com/docs/how_to/sql_prompting/#few-shot-examples) capability to an SQL agent. Example usage: ```python from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX embeddings = OpenAIEmbeddings() example_selector = SemanticSimilarityExampleSelector.from_examples( examples, embeddings, AstraDB, k=5, input_keys=["input"], collection_name="lc_few_shots", token=ASTRA_DB_APPLICATION_TOKEN, api_endpoint=ASTRA_DB_API_ENDPOINT, ) few_shot_sql_tool = FewShotSQLTool( example_selector=example_selector, description="Input to this tool is the input question, output is a few SQL query examples related to the input question. Always use this tool before checking the query with sql_db_query_checker!" ) agent = create_sql_agent( llm=llm, db=db, prefix=SQL_PREFIX + "\nYou MUST get some example queries before creating the query.", extra_tools=[few_shot_sql_tool] ) result = agent.invoke({"input": "How many artists are there?"}) ``` --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
8d746086ab
commit
6ddd5dbb1e
@ -0,0 +1,3 @@
|
||||
from langchain_community.tools.few_shot.tool import FewShotSQLTool
|
||||
|
||||
__all__ = ["FewShotSQLTool"]
|
46
libs/community/langchain_community/tools/few_shot/tool.py
Normal file
46
libs/community/langchain_community/tools/few_shot/tool.py
Normal file
@ -0,0 +1,46 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class _FewShotToolInput(BaseModel):
|
||||
question: str = Field(
|
||||
..., description="The question for which we want example SQL queries."
|
||||
)
|
||||
|
||||
|
||||
class FewShotSQLTool(BaseTool): # type: ignore[override]
|
||||
"""Tool to get example SQL queries related to an input question."""
|
||||
|
||||
name: str = "few_shot_sql"
|
||||
description: str = "Tool to get example SQL queries related to an input question."
|
||||
args_schema: Type[BaseModel] = _FewShotToolInput
|
||||
|
||||
example_selector: BaseExampleSelector = Field(exclude=True)
|
||||
example_input_key: str = "input"
|
||||
example_query_key: str = "query"
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
question: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
example_prompt = PromptTemplate.from_template(
|
||||
f"User input: {self.example_input_key}\nSQL query: {self.example_query_key}"
|
||||
)
|
||||
prompt = FewShotPromptTemplate(
|
||||
example_prompt=example_prompt,
|
||||
example_selector=self.example_selector,
|
||||
suffix="",
|
||||
input_variables=[self.example_input_key],
|
||||
)
|
||||
return prompt.format(**{self.example_input_key: question})
|
63
libs/community/tests/unit_tests/tools/test_few_shot.py
Normal file
63
libs/community/tests/unit_tests/tools/test_few_shot.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
||||
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from langchain_tests.integration_tests import ToolsIntegrationTests
|
||||
from langchain_tests.unit_tests import ToolsUnitTests
|
||||
|
||||
from langchain_community.tools.few_shot.tool import FewShotSQLTool
|
||||
|
||||
EXAMPLES = [
|
||||
{
|
||||
"input": "Number of rows in artist table",
|
||||
"output": "select count(*) from Artist",
|
||||
},
|
||||
{
|
||||
"input": "Number of rows in album table",
|
||||
"output": "select count(*) from Album",
|
||||
},
|
||||
]
|
||||
EMBEDDINGS = DeterministicFakeEmbedding(size=10)
|
||||
|
||||
EXAMPLE_SELECTOR = SemanticSimilarityExampleSelector.from_examples(
|
||||
EXAMPLES,
|
||||
EMBEDDINGS,
|
||||
InMemoryVectorStore,
|
||||
k=5,
|
||||
input_keys=["input"],
|
||||
)
|
||||
|
||||
|
||||
class TestFewShotSQLToolUnit(ToolsUnitTests):
|
||||
@property
|
||||
def tool_constructor(self) -> Type[FewShotSQLTool]:
|
||||
return FewShotSQLTool
|
||||
|
||||
@property
|
||||
def tool_constructor_params(self) -> dict:
|
||||
return {
|
||||
"example_selector": EXAMPLE_SELECTOR,
|
||||
"description": "Use this tool to select examples.",
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_invoke_params_example(self) -> dict:
|
||||
return {"question": "How many rows are in the customer table?"}
|
||||
|
||||
|
||||
class TestFewShotSQLToolIntegration(ToolsIntegrationTests):
|
||||
@property
|
||||
def tool_constructor(self) -> Type[FewShotSQLTool]:
|
||||
return FewShotSQLTool
|
||||
|
||||
@property
|
||||
def tool_constructor_params(self) -> dict:
|
||||
return {
|
||||
"example_selector": EXAMPLE_SELECTOR,
|
||||
"description": "Use this tool to select examples.",
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_invoke_params_example(self) -> dict:
|
||||
return {"question": "How many rows are in the customer table?"}
|
Loading…
Reference in New Issue
Block a user