diff --git a/libs/community/langchain_community/tools/few_shot/__init__.py b/libs/community/langchain_community/tools/few_shot/__init__.py
new file mode 100644
index 00000000000..e19f14575f5
--- /dev/null
+++ b/libs/community/langchain_community/tools/few_shot/__init__.py
@@ -0,0 +1,5 @@
+"""Tools that surface few-shot examples to a model."""
+
+from langchain_community.tools.few_shot.tool import FewShotSQLTool
+
+__all__ = ["FewShotSQLTool"]
diff --git a/libs/community/langchain_community/tools/few_shot/tool.py b/libs/community/langchain_community/tools/few_shot/tool.py
new file mode 100644
index 00000000000..2387ca4edb7
--- /dev/null
+++ b/libs/community/langchain_community/tools/few_shot/tool.py
@@ -0,0 +1,66 @@
+"""Tool that retrieves few-shot example SQL queries for a question."""
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from langchain_core.example_selectors import BaseExampleSelector
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class _FewShotToolInput(BaseModel):
+    """Input schema for ``FewShotSQLTool``."""
+
+    question: str = Field(
+        ..., description="The question for which we want example SQL queries."
+    )
+
+
+class FewShotSQLTool(BaseTool):  # type: ignore[override]
+    """Tool to get example SQL queries related to an input question."""
+
+    name: str = "few_shot_sql"
+    description: str = "Tool to get example SQL queries related to an input question."
+    args_schema: Type[BaseModel] = _FewShotToolInput
+
+    # Selector that picks the examples to render for a given question.
+    example_selector: BaseExampleSelector = Field(exclude=True)
+    # Example key for the question text; the question is passed under it too.
+    example_input_key: str = "input"
+    # Example key for the SQL query text.
+    example_query_key: str = "query"
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
+
+    def _run(
+        self,
+        question: str,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Return the examples selected for ``question`` as formatted text.
+
+        Args:
+            question: The question to find related example SQL queries for.
+            run_manager: Callback manager for the tool run (unused).
+
+        Returns:
+            The selected examples, each rendered from the
+            ``User input: ... / SQL query: ...`` template.
+        """
+        # The doubled braces emit literal ``{`` / ``}`` so the configured keys
+        # become template variables (e.g. ``{input}``) rather than being baked
+        # into the template as plain text.
+        example_prompt = PromptTemplate.from_template(
+            f"User input: {{{self.example_input_key}}}\n"
+            f"SQL query: {{{self.example_query_key}}}"
+        )
+        prompt = FewShotPromptTemplate(
+            example_prompt=example_prompt,
+            example_selector=self.example_selector,
+            suffix="",
+            input_variables=[self.example_input_key],
+        )
+        return prompt.format(**{self.example_input_key: question})
diff --git a/libs/community/tests/unit_tests/tools/test_few_shot.py b/libs/community/tests/unit_tests/tools/test_few_shot.py
new file mode 100644
index 00000000000..4a59fc6f54c
--- /dev/null
+++ b/libs/community/tests/unit_tests/tools/test_few_shot.py
@@ -0,0 +1,78 @@
+"""Tests for ``FewShotSQLTool``."""
+
+from typing import Type
+
+from langchain_core.embeddings.fake import DeterministicFakeEmbedding
+from langchain_core.example_selectors import SemanticSimilarityExampleSelector
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_tests.integration_tests import ToolsIntegrationTests
+from langchain_tests.unit_tests import ToolsUnitTests
+
+from langchain_community.tools.few_shot.tool import FewShotSQLTool
+
+# Keys match the tool's default ``example_input_key`` / ``example_query_key``.
+EXAMPLES = [
+    {
+        "input": "Number of rows in artist table",
+        "query": "select count(*) from Artist",
+    },
+    {
+        "input": "Number of rows in album table",
+        "query": "select count(*) from Album",
+    },
+]
+EMBEDDINGS = DeterministicFakeEmbedding(size=10)
+
+EXAMPLE_SELECTOR = SemanticSimilarityExampleSelector.from_examples(
+    EXAMPLES,
+    EMBEDDINGS,
+    InMemoryVectorStore,
+    k=5,
+    input_keys=["input"],
+)
+
+
+def test_run_renders_example_values() -> None:
+    """The output contains the examples' text, not the bare key names."""
+    tool = FewShotSQLTool(example_selector=EXAMPLE_SELECTOR)
+    result = tool.invoke({"question": "Number of rows in artist table"})
+    assert "User input: Number of rows in artist table" in result
+    assert "SQL query: select count(*) from Artist" in result
+
+
+class TestFewShotSQLToolUnit(ToolsUnitTests):
+    """Run the shared ``ToolsUnitTests`` suite against ``FewShotSQLTool``."""
+
+    @property
+    def tool_constructor(self) -> Type[FewShotSQLTool]:
+        return FewShotSQLTool
+
+    @property
+    def tool_constructor_params(self) -> dict:
+        return {
+            "example_selector": EXAMPLE_SELECTOR,
+            "description": "Use this tool to select examples.",
+        }
+
+    @property
+    def tool_invoke_params_example(self) -> dict:
+        return {"question": "How many rows are in the customer table?"}
+
+
+class TestFewShotSQLToolIntegration(ToolsIntegrationTests):
+    """Run the shared ``ToolsIntegrationTests`` suite against ``FewShotSQLTool``."""
+
+    @property
+    def tool_constructor(self) -> Type[FewShotSQLTool]:
+        return FewShotSQLTool
+
+    @property
+    def tool_constructor_params(self) -> dict:
+        return {
+            "example_selector": EXAMPLE_SELECTOR,
+            "description": "Use this tool to select examples.",
+        }
+
+    @property
+    def tool_invoke_params_example(self) -> dict:
+        return {"question": "How many rows are in the customer table?"}