From 450c458f8f07f1a1493a13a7b29f17b84820f90d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Carlos=20Jos=C3=A9=20Camacho?= Date: Sat, 13 Apr 2024 17:27:16 -0600 Subject: [PATCH] community[minor]: Add Datahareld tool (#19680) **Description:** Integrate [dataherald](https://www.dataherald.com) tool, It is a natural language-to-SQL tool. **Dependencies:** Install dataherald sdk to use it, ``` pip install dataherald ``` --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur Co-authored-by: Christophe Bornet --- .../integrations/providers/dataherald.mdx | 64 ++++++++++ docs/docs/integrations/tools/dataherald.ipynb | 117 ++++++++++++++++++ .../langchain_community/tools/__init__.py | 1 + .../tools/dataherald/__init__.py | 8 ++ .../tools/dataherald/tool.py | 36 ++++++ .../langchain_community/utilities/__init__.py | 1 + .../utilities/dataherald.py | 67 ++++++++++ .../utilities/test_dataherald_api.py | 9 ++ .../tests/unit_tests/tools/test_imports.py | 1 + .../tests/unit_tests/tools/test_public_api.py | 1 + .../unit_tests/utilities/test_imports.py | 1 + 11 files changed, 306 insertions(+) create mode 100644 docs/docs/integrations/providers/dataherald.mdx create mode 100644 docs/docs/integrations/tools/dataherald.ipynb create mode 100644 libs/community/langchain_community/tools/dataherald/__init__.py create mode 100644 libs/community/langchain_community/tools/dataherald/tool.py create mode 100644 libs/community/langchain_community/utilities/dataherald.py create mode 100644 libs/community/tests/integration_tests/utilities/test_dataherald_api.py diff --git a/docs/docs/integrations/providers/dataherald.mdx b/docs/docs/integrations/providers/dataherald.mdx new file mode 100644 index 00000000000..d7e11be48fb --- /dev/null +++ b/docs/docs/integrations/providers/dataherald.mdx @@ -0,0 +1,64 @@ +# Dataherald + +>[Dataherald](https://www.dataherald.com) is a natural language-to-SQL tool. 
+ +This page covers how to use the `Dataherald API` within LangChain. + +## Installation and Setup +- Install requirements with +```bash +pip install dataherald +``` +- Go to dataherald and sign up [here](https://www.dataherald.com) +- Create an app and get your `API KEY` +- Set your `API KEY` as an environment variable `DATAHERALD_API_KEY` + + +## Wrappers + +### Utility + +There exists a DataheraldAPIWrapper utility which wraps this API. To import this utility: + +```python +from langchain_community.utilities.dataherald import DataheraldAPIWrapper +``` + +For a more detailed walkthrough of this wrapper, see [this notebook](/docs/integrations/tools/dataherald). + +### Tool + +You can use the tool in an agent like this: +```python +from langchain_community.utilities.dataherald import DataheraldAPIWrapper +from langchain_community.tools.dataherald.tool import DataheraldTextToSQL +from langchain_openai import ChatOpenAI +from langchain import hub +from langchain.agents import AgentExecutor, create_react_agent, load_tools + +api_wrapper = DataheraldAPIWrapper(db_connection_id="") +tools = [DataheraldTextToSQL(api_wrapper=api_wrapper)] +llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) +prompt = hub.pull("hwchase17/react") +agent = create_react_agent(llm, tools, prompt) +agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) +agent_executor.invoke({"input":"Return the sql for this question: How many employees are in the company?"}) +``` + +Output +```shell +> Entering new AgentExecutor chain... +I need to use a tool that can convert this question into SQL. +Action: dataherald +Action Input: How many employees are in the company?Answer: SELECT + COUNT(*) FROM employeesI now know the final answer +Final Answer: SELECT + COUNT(*) +FROM + employees + +> Finished chain. 
+{'input': 'Return the sql for this question: How many employees are in the company?', 'output': "SELECT \n COUNT(*)\nFROM \n employees"} +``` + +For more information on tools, see [this page](/docs/modules/tools/). diff --git a/docs/docs/integrations/tools/dataherald.ipynb b/docs/docs/integrations/tools/dataherald.ipynb new file mode 100644 index 00000000000..bfb9bf35b2a --- /dev/null +++ b/docs/docs/integrations/tools/dataherald.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "245a954a", + "metadata": {}, + "source": [ + "# Dataherald\n", + "\n", + "This notebook goes over how to use the dataherald component.\n", + "\n", + "First, you need to set up your Dataherald account and get your API KEY:\n", + "\n", + "1. Go to dataherald and sign up [here](https://www.dataherald.com/)\n", + "2. Once you are logged in your Admin Console, create an API KEY\n", + "3. pip install dataherald\n", + "\n", + "Then we will need to set some environment variables:\n", + "1. 
Save your API KEY into DATAHERALD_API_KEY env variable" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "961b3689", + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "pip install dataherald" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "34bb5968", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"DATAHERALD_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ac4910f8", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.utilities.dataherald import DataheraldAPIWrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "84b8f773", + "metadata": {}, + "outputs": [], + "source": [ + "dataherald = DataheraldAPIWrapper(db_connection_id=\"65fb766367dd22c99ce1a12d\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "068991a6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'select COUNT(*) from employees'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataherald.run(\"How many employees are in the company?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "53f3bc57609c7a84333bb558594977aa5b4026b1d6070b93987956689e367341" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py index 53ddd31e7b2..48db105fcd9 100644 --- 
a/libs/community/langchain_community/tools/__init__.py +++ b/libs/community/langchain_community/tools/__init__.py @@ -485,6 +485,7 @@ _module_lookup = { "ConneryAction": "langchain_community.tools.connery", "CopyFileTool": "langchain_community.tools.file_management", "CurrentWebPageTool": "langchain_community.tools.playwright", + "DataheraldTextToSQL": "langchain_community.tools.dataherald.tool", "DeleteFileTool": "langchain_community.tools.file_management", "DuckDuckGoSearchResults": "langchain_community.tools.ddg_search.tool", "DuckDuckGoSearchRun": "langchain_community.tools.ddg_search.tool", diff --git a/libs/community/langchain_community/tools/dataherald/__init__.py b/libs/community/langchain_community/tools/dataherald/__init__.py new file mode 100644 index 00000000000..319d19b8e7f --- /dev/null +++ b/libs/community/langchain_community/tools/dataherald/__init__.py @@ -0,0 +1,8 @@ +"""Dataherald API toolkit.""" + + +from langchain_community.tools.dataherald.tool import DataheraldTextToSQL + +__all__ = [ + "DataheraldTextToSQL", +] diff --git a/libs/community/langchain_community/tools/dataherald/tool.py b/libs/community/langchain_community/tools/dataherald/tool.py new file mode 100644 index 00000000000..90c4cef7102 --- /dev/null +++ b/libs/community/langchain_community/tools/dataherald/tool.py @@ -0,0 +1,36 @@ +"""Tool for the Dataherald Hosted API""" + +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import BaseTool + +from langchain_community.utilities.dataherald import DataheraldAPIWrapper + + +class DataheraldTextToSQLInput(BaseModel): + prompt: str = Field( + description="Natural language query to be translated to a SQL query." + ) + + +class DataheraldTextToSQL(BaseTool): + """Tool that translates a natural language question into SQL via Dataherald.""" + + name: str = "dataherald" + description: str = ( + "A wrapper around Dataherald. " + "Text to SQL. 
" + "Input should be a natural language question to translate into SQL." + ) + api_wrapper: DataheraldAPIWrapper + args_schema: Type[BaseModel] = DataheraldTextToSQLInput + + def _run( + self, + prompt: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Pass ``prompt`` to ``api_wrapper.run`` and return its string result.""" + return self.api_wrapper.run(prompt) diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py index 582eb856dd6..148447c8446 100644 --- a/libs/community/langchain_community/utilities/__init__.py +++ b/libs/community/langchain_community/utilities/__init__.py @@ -230,6 +230,7 @@ _module_lookup = { "BibtexparserWrapper": "langchain_community.utilities.bibtex", "BingSearchAPIWrapper": "langchain_community.utilities.bing_search", "BraveSearchWrapper": "langchain_community.utilities.brave_search", + "DataheraldAPIWrapper": "langchain_community.utilities.dataherald", "DriaAPIWrapper": "langchain_community.utilities.dria_index", "DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search", "GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query", diff --git a/libs/community/langchain_community/utilities/dataherald.py b/libs/community/langchain_community/utilities/dataherald.py new file mode 100644 index 00000000000..a085e23bb15 --- /dev/null +++ b/libs/community/langchain_community/utilities/dataherald.py @@ -0,0 +1,67 @@ +"""Util that calls Dataherald.""" +from typing import Any, Dict, Optional + +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.utils import get_from_dict_or_env + + +class DataheraldAPIWrapper(BaseModel): + """Wrapper for Dataherald. + + Docs for using: + + 1. Go to dataherald and sign up + 2. Create an API key + 3. Save your API key into DATAHERALD_API_KEY env variable + 4. 
pip install dataherald + + """ + + dataherald_client: Any #: :meta private: + db_connection_id: str + dataherald_api_key: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that the API key and dataherald package are available.""" + dataherald_api_key = get_from_dict_or_env( + values, "dataherald_api_key", "DATAHERALD_API_KEY" + ) + values["dataherald_api_key"] = dataherald_api_key + + try: + import dataherald + + except ImportError: + raise ImportError( + "dataherald is not installed. " + "Please install it with `pip install dataherald`" + ) + + client = dataherald.Dataherald(api_key=dataherald_api_key) + values["dataherald_client"] = client + + return values + + def run(self, prompt: str) -> str: + """Generate a sql query through Dataherald and parse result.""" + from dataherald.types.sql_generation_create_params import Prompt + + prompt_obj = Prompt(text=prompt, db_connection_id=self.db_connection_id) + res = self.dataherald_client.sql_generations.create(prompt=prompt_obj) + + try: + answer = res.sql + if not answer: + # Dataherald returned no SQL for this prompt; say so explicitly + return "No answer" + else: + return f"Answer: {answer}" + + except StopIteration:  # NOTE(review): likely unreachable; confirm + return "Dataherald wasn't able to answer it" diff --git a/libs/community/tests/integration_tests/utilities/test_dataherald_api.py b/libs/community/tests/integration_tests/utilities/test_dataherald_api.py new file mode 100644 index 00000000000..8556dad408f --- /dev/null +++ b/libs/community/tests/integration_tests/utilities/test_dataherald_api.py @@ -0,0 +1,9 @@ +"""Integration test for Dataherald API Wrapper.""" +from langchain_community.utilities.dataherald import DataheraldAPIWrapper + + +def test_call() -> None: + """Test that call gives the correct answer.""" + search = DataheraldAPIWrapper(db_connection_id="65fb766367dd22c99ce1a12d") + output = 
search.run("How many employees are in the company?") + assert "Answer: SELECT \n COUNT(*) FROM \n employees" in output diff --git a/libs/community/tests/unit_tests/tools/test_imports.py b/libs/community/tests/unit_tests/tools/test_imports.py index 2d8e8754c4f..c6ae50302d1 100644 --- a/libs/community/tests/unit_tests/tools/test_imports.py +++ b/libs/community/tests/unit_tests/tools/test_imports.py @@ -36,6 +36,7 @@ EXPECTED_ALL = [ "ConneryAction", "CopyFileTool", "CurrentWebPageTool", + "DataheraldTextToSQL", "DeleteFileTool", "DuckDuckGoSearchResults", "DuckDuckGoSearchRun", diff --git a/libs/community/tests/unit_tests/tools/test_public_api.py b/libs/community/tests/unit_tests/tools/test_public_api.py index 5a4d2af51ec..a4fe6e89a42 100644 --- a/libs/community/tests/unit_tests/tools/test_public_api.py +++ b/libs/community/tests/unit_tests/tools/test_public_api.py @@ -37,6 +37,7 @@ _EXPECTED = [ "ConneryAction", "CopyFileTool", "CurrentWebPageTool", + "DataheraldTextToSQL", "DeleteFileTool", "DuckDuckGoSearchResults", "DuckDuckGoSearchRun", diff --git a/libs/community/tests/unit_tests/utilities/test_imports.py b/libs/community/tests/unit_tests/utilities/test_imports.py index 5adb6f9a58e..c6e2951b9b5 100644 --- a/libs/community/tests/unit_tests/utilities/test_imports.py +++ b/libs/community/tests/unit_tests/utilities/test_imports.py @@ -8,6 +8,7 @@ EXPECTED_ALL = [ "BibtexparserWrapper", "BingSearchAPIWrapper", "BraveSearchWrapper", + "DataheraldAPIWrapper", "DuckDuckGoSearchAPIWrapper", "DriaAPIWrapper", "GoldenQueryAPIWrapper",