From 6ddd5dbb1e9e665b64131624003c0abffbdaf830 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 16 Dec 2024 16:37:21 +0100 Subject: [PATCH] community: Add FewShotSQLTool (#28232) The `FewShotSQLTool` gets some SQL query examples from a `BaseExampleSelector` for a given question. This is useful to provide [few-shot examples](https://python.langchain.com/docs/how_to/sql_prompting/#few-shot-examples) capability to an SQL agent. Example usage: ```python from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX embeddings = OpenAIEmbeddings() example_selector = SemanticSimilarityExampleSelector.from_examples( examples, embeddings, AstraDB, k=5, input_keys=["input"], collection_name="lc_few_shots", token=ASTRA_DB_APPLICATION_TOKEN, api_endpoint=ASTRA_DB_API_ENDPOINT, ) few_shot_sql_tool = FewShotSQLTool( example_selector=example_selector, description="Input to this tool is the input question, output is a few SQL query examples related to the input question. Always use this tool before checking the query with sql_db_query_checker!" 
) agent = create_sql_agent( llm=llm, db=db, prefix=SQL_PREFIX + "\nYou MUST get some example queries before creating the query.", extra_tools=[few_shot_sql_tool] ) result = agent.invoke({"input": "How many artists are there?"}) ``` --------- Co-authored-by: Chester Curme --- .../tools/few_shot/__init__.py | 3 + .../tools/few_shot/tool.py | 46 ++++++++++++++ .../tests/unit_tests/tools/test_few_shot.py | 63 +++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 libs/community/langchain_community/tools/few_shot/__init__.py create mode 100644 libs/community/langchain_community/tools/few_shot/tool.py create mode 100644 libs/community/tests/unit_tests/tools/test_few_shot.py diff --git a/libs/community/langchain_community/tools/few_shot/__init__.py b/libs/community/langchain_community/tools/few_shot/__init__.py new file mode 100644 index 00000000000..e19f14575f5 --- /dev/null +++ b/libs/community/langchain_community/tools/few_shot/__init__.py @@ -0,0 +1,3 @@ +from langchain_community.tools.few_shot.tool import FewShotSQLTool + +__all__ = ["FewShotSQLTool"] diff --git a/libs/community/langchain_community/tools/few_shot/tool.py b/libs/community/langchain_community/tools/few_shot/tool.py new file mode 100644 index 00000000000..2387ca4edb7 --- /dev/null +++ b/libs/community/langchain_community/tools/few_shot/tool.py @@ -0,0 +1,46 @@ +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.example_selectors import BaseExampleSelector +from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate +from langchain_core.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field + + +class _FewShotToolInput(BaseModel): + question: str = Field( + ..., description="The question for which we want example SQL queries." 
+    )
+
+
+class FewShotSQLTool(BaseTool):  # type: ignore[override]
+    """Tool to get example SQL queries related to an input question."""
+
+    name: str = "few_shot_sql"
+    description: str = "Tool to get example SQL queries related to an input question."
+    args_schema: Type[BaseModel] = _FewShotToolInput
+
+    example_selector: BaseExampleSelector = Field(exclude=True)
+    example_input_key: str = "input"
+    example_query_key: str = "query"
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def _run(
+        self,
+        question: str,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Return the selected few-shot examples for ``question`` as one string."""
+        input_key, query_key = self.example_input_key, self.example_query_key
+        # Tripled braces render as literal "{key}" placeholders; without them the
+        # key names were emitted verbatim and every example rendered identically.
+        template = f"User input: {{{input_key}}}\nSQL query: {{{query_key}}}"
+        example_prompt = PromptTemplate.from_template(template)
+        prompt = FewShotPromptTemplate(
+            example_prompt=example_prompt,
+            example_selector=self.example_selector,
+            suffix="",
+            input_variables=[input_key],
+        )
+        return prompt.format(**{input_key: question})
diff --git a/libs/community/tests/unit_tests/tools/test_few_shot.py b/libs/community/tests/unit_tests/tools/test_few_shot.py
new file mode 100644
index 00000000000..4a59fc6f54c
--- /dev/null
+++ b/libs/community/tests/unit_tests/tools/test_few_shot.py
@@ -0,0 +1,63 @@
+from typing import Type
+
+from langchain_core.embeddings.fake import DeterministicFakeEmbedding
+from langchain_core.example_selectors import SemanticSimilarityExampleSelector
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_tests.integration_tests import ToolsIntegrationTests
+from langchain_tests.unit_tests import ToolsUnitTests
+
+from langchain_community.tools.few_shot.tool import FewShotSQLTool
+
+EXAMPLES = [
+    {
+        "input": "Number of rows in artist table",
+        "query": "select count(*) from Artist",
+    },
+    {
+        "input": "Number of rows in album table",
+        "query": "select count(*) from Album",
+    },
+]
+EMBEDDINGS = 
DeterministicFakeEmbedding(size=10) + +EXAMPLE_SELECTOR = SemanticSimilarityExampleSelector.from_examples( + EXAMPLES, + EMBEDDINGS, + InMemoryVectorStore, + k=5, + input_keys=["input"], +) + + +class TestFewShotSQLToolUnit(ToolsUnitTests): + @property + def tool_constructor(self) -> Type[FewShotSQLTool]: + return FewShotSQLTool + + @property + def tool_constructor_params(self) -> dict: + return { + "example_selector": EXAMPLE_SELECTOR, + "description": "Use this tool to select examples.", + } + + @property + def tool_invoke_params_example(self) -> dict: + return {"question": "How many rows are in the customer table?"} + + +class TestFewShotSQLToolIntegration(ToolsIntegrationTests): + @property + def tool_constructor(self) -> Type[FewShotSQLTool]: + return FewShotSQLTool + + @property + def tool_constructor_params(self) -> dict: + return { + "example_selector": EXAMPLE_SELECTOR, + "description": "Use this tool to select examples.", + } + + @property + def tool_invoke_params_example(self) -> dict: + return {"question": "How many rows are in the customer table?"}