mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
Refactored sql_database
(#7945)
`sql_database.py` is unnecessarily placed in the root code folder; similar code usually lives in `utilities/`. As a byproduct of this placement, `sql_database` is [listed at the top level of classes in the API Reference](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.sql_database), which is confusing and incorrect. - Moved `sql_database.py` from the root code folder to `utilities/`. @baskaryan --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
dc9d6cadab
commit
ae8bc9e830
@ -43,7 +43,6 @@ from langchain.prompts import (
|
|||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema.prompt_template import BasePromptTemplate
|
from langchain.schema.prompt_template import BasePromptTemplate
|
||||||
from langchain.sql_database import SQLDatabase
|
|
||||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||||
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
|
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
|
||||||
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
||||||
@ -51,6 +50,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
|||||||
from langchain.utilities.powerbi import PowerBIDataset
|
from langchain.utilities.powerbi import PowerBIDataset
|
||||||
from langchain.utilities.searx_search import SearxSearchWrapper
|
from langchain.utilities.searx_search import SearxSearchWrapper
|
||||||
from langchain.utilities.serpapi import SerpAPIWrapper
|
from langchain.utilities.serpapi import SerpAPIWrapper
|
||||||
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||||
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
||||||
from langchain.vectorstores import FAISS, ElasticVectorSearch
|
from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||||
|
@ -5,7 +5,6 @@ from pydantic import Field
|
|||||||
|
|
||||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.sql_database import SQLDatabase
|
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langchain.tools.sql_database.tool import (
|
from langchain.tools.sql_database.tool import (
|
||||||
InfoSQLDatabaseTool,
|
InfoSQLDatabaseTool,
|
||||||
@ -13,6 +12,7 @@ from langchain.tools.sql_database.tool import (
|
|||||||
QuerySQLCheckerTool,
|
QuerySQLCheckerTool,
|
||||||
QuerySQLDataBaseTool,
|
QuerySQLDataBaseTool,
|
||||||
)
|
)
|
||||||
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabaseToolkit(BaseToolkit):
|
class SQLDatabaseToolkit(BaseToolkit):
|
||||||
|
@ -13,8 +13,8 @@ from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PRO
|
|||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.schema import BasePromptTemplate
|
from langchain.schema import BasePromptTemplate
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.sql_database import SQLDatabase
|
|
||||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||||
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
@ -1,445 +1,4 @@
|
|||||||
"""SQLAlchemy wrapper around a database."""
|
"""Keep here for backwards compatibility."""
|
||||||
from __future__ import annotations
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
import warnings
|
__all__ = ["SQLDatabase"]
|
||||||
from typing import Any, Iterable, List, Optional
|
|
||||||
|
|
||||||
import sqlalchemy
|
|
||||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
|
||||||
from sqlalchemy.schema import CreateTable
|
|
||||||
|
|
||||||
from langchain.utils import get_from_env
|
|
||||||
|
|
||||||
|
|
||||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
|
||||||
return (
|
|
||||||
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
|
||||||
f' Columns: {str(index["column_names"])}'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters at a word boundary.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the returned string, suffix included. A
            value of zero or less disables truncation.
        suffix: Marker appended to a truncated string.

    Returns:
        ``content`` itself when it is not a string, when ``length <= 0`` or
        when it already fits. Otherwise the leading part of ``content``, cut
        back to the last space inside the allowed width, followed by
        ``suffix``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    keep = length - len(suffix)
    if keep <= 0:
        # There is no room for any content. A negative slice bound would count
        # from the END of the string and produce a result longer than `length`,
        # so return as much of the suffix as fits instead.
        return suffix[:length]

    return content[:keep].rsplit(" ", 1)[0] + suffix
|
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabase:
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap an existing SQLAlchemy engine and reflect its usable tables.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Optional schema whose tables are inspected and reflected;
                ``run`` also switches to it before executing a command.
            metadata: Optional ``MetaData`` to reflect into. A new one is
                created when omitted.
            ignore_tables: Table names to hide. Cannot be combined with
                ``include_tables``.
            include_tables: The only table names to expose. Cannot be combined
                with ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                generated table description; ``0`` disables them.
            indexes_in_table_info: Whether generated table descriptions list
                the table indexes.
            custom_table_info: Mapping of table name to a description used in
                place of the generated one. Entries for tables that are not in
                the database are dropped.
            view_support: Whether views are listed and reflected like tables.
            max_string_length: ``length`` passed to ``truncate_word`` for each
                value returned by ``run``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table that is not in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an integer or a
                non-empty ``custom_table_info`` is not a dictionary.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # including view support if view_support = true
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Create an engine from ``database_uri`` and wrap it.

        Args:
            database_uri: URI handed to ``create_engine``.
            engine_args: Optional keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance bound to the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                to Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ValueError: If 'databricks-sql-connector' is not found, or if both
                'warehouse_id' and 'cluster_id' are provided, or if neither
                'warehouse_id' nor 'cluster_id' are provided and it's not executing
                inside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as err:
            raise ValueError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from err
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            # Deliberate best effort: the notebook context is only a source of
            # defaults, so its absence is not an error.
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ValueError: If the 'cnos-connector' package cannot be imported.
        """
        # Only the import is guarded, so an ImportError raised while building
        # the engine is not misreported as a missing cnos-connector package.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as err:
            raise ValueError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from err

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Return the sorted names of the tables exposed by this wrapper.

        When ``include_tables`` was given, exactly those names are returned;
        otherwise every known table minus ``ignore_tables``.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def get_table_names(self) -> Iterable[str]:
        """Deprecated alias of ``get_usable_table_names``; emits a warning."""
        warnings.warn(
            "This method is deprecated - please use `get_usable_table_names`."
        )
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe. All usable tables are described
                when omitted.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If ``table_names`` contains a name that is not a usable
                table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set once instead of once per reflected table.
        wanted = set(all_table_names)
        skip_sqlite_internal = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted
            and not (skip_sqlite_internal and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Indexes and sample rows are wrapped in one SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Return a "Table Indexes:" section listing the indexes of ``table``."""
        # Pass the table's own schema so tables outside the default schema are
        # looked up in the right place.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        Each value is converted to a string and cut to 100 characters.
        """
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # shorten values in the sample rows; the result is consumed
                # while the connection is still open
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _format_result(self, result: Any) -> str:
        """Render fetched rows as the string returned by ``run``.

        Args:
            result: A list of rows, a single row, or ``None`` when a
                single-row fetch found nothing.

        Returns:
            ``""`` for ``None``; otherwise the ``str`` of a list of tuples (or
            of one tuple) whose values were passed through ``truncate_word``.
        """
        if result is None:
            # A single-row fetch with no row left: report it like a statement
            # that returns no rows instead of failing on iteration.
            return ""

        limit = self._max_string_length
        # Convert columns values to string to avoid issues with sqlalchemy
        # truncating text
        if isinstance(result, list):
            return str(
                [tuple(truncate_word(c, length=limit) for c in r) for r in result]
            )
        return str(tuple(truncate_word(c, length=limit) for c in result))

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        Args:
            command: SQL text to execute.
            fetch: ``"all"`` to return every row or ``"one"`` to return only
                the first row.

        Raises:
            ValueError: If the statement returns rows and ``fetch`` is neither
                ``"one"`` nor ``"all"``.
        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                # NOTE(review): the schema name is interpolated into these
                # statements unescaped; it must come from trusted configuration.
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET search_path='{self._schema}'"
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
                else:
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")
            cursor = connection.execute(text(command))
            if cursor.returns_rows:
                if fetch == "all":
                    result = cursor.fetchall()
                elif fetch == "one":
                    result = cursor.fetchone()  # type: ignore
                else:
                    raise ValueError("Fetch parameter must be either 'one' or 'all'")

                return self._format_result(result)
        return ""

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` raised by ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of being propagated.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws a ``SQLAlchemyError``, the error message is
        returned as an ``"Error: ..."`` string.
        """
        try:
            return self.run(command, fetch)
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"
|
|
||||||
|
@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from langchain.utilities.scenexplain import SceneXplainAPIWrapper
|
|||||||
from langchain.utilities.searx_search import SearxSearchWrapper
|
from langchain.utilities.searx_search import SearxSearchWrapper
|
||||||
from langchain.utilities.serpapi import SerpAPIWrapper
|
from langchain.utilities.serpapi import SerpAPIWrapper
|
||||||
from langchain.utilities.spark_sql import SparkSQL
|
from langchain.utilities.spark_sql import SparkSQL
|
||||||
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
from langchain.utilities.twilio import TwilioAPIWrapper
|
from langchain.utilities.twilio import TwilioAPIWrapper
|
||||||
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||||
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
||||||
@ -56,6 +57,7 @@ __all__ = [
|
|||||||
"SearxSearchWrapper",
|
"SearxSearchWrapper",
|
||||||
"SerpAPIWrapper",
|
"SerpAPIWrapper",
|
||||||
"SparkSQL",
|
"SparkSQL",
|
||||||
|
"SQLDatabase",
|
||||||
"TextRequestsWrapper",
|
"TextRequestsWrapper",
|
||||||
"TwilioAPIWrapper",
|
"TwilioAPIWrapper",
|
||||||
"WikipediaAPIWrapper",
|
"WikipediaAPIWrapper",
|
||||||
|
445
langchain/utilities/sql_database.py
Normal file
445
langchain/utilities/sql_database.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||||
|
from sqlalchemy.schema import CreateTable
|
||||||
|
|
||||||
|
from langchain.utils import get_from_env
|
||||||
|
|
||||||
|
|
||||||
|
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||||
|
return (
|
||||||
|
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
||||||
|
f' Columns: {str(index["column_names"])}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters at a word boundary.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the returned string, suffix included. A
            value of zero or less disables truncation.
        suffix: Marker appended to a truncated string.

    Returns:
        ``content`` itself when it is not a string, when ``length <= 0`` or
        when it already fits. Otherwise the leading part of ``content``, cut
        back to the last space inside the allowed width, followed by
        ``suffix``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    keep = length - len(suffix)
    if keep <= 0:
        # There is no room for any content. A negative slice bound would count
        # from the END of the string and produce a result longer than `length`,
        # so return as much of the suffix as fits instead.
        return suffix[:length]

    return content[:keep].rsplit(" ", 1)[0] + suffix
|
||||||
|
|
||||||
|
|
||||||
|
class SQLDatabase:
|
||||||
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine: Engine,
|
||||||
|
schema: Optional[str] = None,
|
||||||
|
metadata: Optional[MetaData] = None,
|
||||||
|
ignore_tables: Optional[List[str]] = None,
|
||||||
|
include_tables: Optional[List[str]] = None,
|
||||||
|
sample_rows_in_table_info: int = 3,
|
||||||
|
indexes_in_table_info: bool = False,
|
||||||
|
custom_table_info: Optional[dict] = None,
|
||||||
|
view_support: bool = False,
|
||||||
|
max_string_length: int = 300,
|
||||||
|
):
|
||||||
|
"""Create engine from database URI."""
|
||||||
|
self._engine = engine
|
||||||
|
self._schema = schema
|
||||||
|
if include_tables and ignore_tables:
|
||||||
|
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||||
|
|
||||||
|
self._inspector = inspect(self._engine)
|
||||||
|
|
||||||
|
# including view support by adding the views as well as tables to the all
|
||||||
|
# tables list if view_support is True
|
||||||
|
self._all_tables = set(
|
||||||
|
self._inspector.get_table_names(schema=schema)
|
||||||
|
+ (self._inspector.get_view_names(schema=schema) if view_support else [])
|
||||||
|
)
|
||||||
|
|
||||||
|
self._include_tables = set(include_tables) if include_tables else set()
|
||||||
|
if self._include_tables:
|
||||||
|
missing_tables = self._include_tables - self._all_tables
|
||||||
|
if missing_tables:
|
||||||
|
raise ValueError(
|
||||||
|
f"include_tables {missing_tables} not found in database"
|
||||||
|
)
|
||||||
|
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
|
||||||
|
if self._ignore_tables:
|
||||||
|
missing_tables = self._ignore_tables - self._all_tables
|
||||||
|
if missing_tables:
|
||||||
|
raise ValueError(
|
||||||
|
f"ignore_tables {missing_tables} not found in database"
|
||||||
|
)
|
||||||
|
usable_tables = self.get_usable_table_names()
|
||||||
|
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
|
||||||
|
|
||||||
|
if not isinstance(sample_rows_in_table_info, int):
|
||||||
|
raise TypeError("sample_rows_in_table_info must be an integer")
|
||||||
|
|
||||||
|
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||||
|
self._indexes_in_table_info = indexes_in_table_info
|
||||||
|
|
||||||
|
self._custom_table_info = custom_table_info
|
||||||
|
if self._custom_table_info:
|
||||||
|
if not isinstance(self._custom_table_info, dict):
|
||||||
|
raise TypeError(
|
||||||
|
"table_info must be a dictionary with table names as keys and the "
|
||||||
|
"desired table info as values"
|
||||||
|
)
|
||||||
|
# only keep the tables that are also present in the database
|
||||||
|
intersection = set(self._custom_table_info).intersection(self._all_tables)
|
||||||
|
self._custom_table_info = dict(
|
||||||
|
(table, self._custom_table_info[table])
|
||||||
|
for table in self._custom_table_info
|
||||||
|
if table in intersection
|
||||||
|
)
|
||||||
|
|
||||||
|
self._max_string_length = max_string_length
|
||||||
|
|
||||||
|
self._metadata = metadata or MetaData()
|
||||||
|
# including view support if view_support = true
|
||||||
|
self._metadata.reflect(
|
||||||
|
views=view_support,
|
||||||
|
bind=self._engine,
|
||||||
|
only=list(self._usable_tables),
|
||||||
|
schema=self._schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_uri(
|
||||||
|
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
|
) -> SQLDatabase:
|
||||||
|
"""Construct a SQLAlchemy engine from URI."""
|
||||||
|
_engine_args = engine_args or {}
|
||||||
|
return cls(create_engine(database_uri, **_engine_args), **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_databricks(
|
||||||
|
cls,
|
||||||
|
catalog: str,
|
||||||
|
schema: str,
|
||||||
|
host: Optional[str] = None,
|
||||||
|
api_token: Optional[str] = None,
|
||||||
|
warehouse_id: Optional[str] = None,
|
||||||
|
cluster_id: Optional[str] = None,
|
||||||
|
engine_args: Optional[dict] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> SQLDatabase:
|
||||||
|
"""
|
||||||
|
Class method to create an SQLDatabase instance from a Databricks connection.
|
||||||
|
This method requires the 'databricks-sql-connector' package. If not installed,
|
||||||
|
it can be added using `pip install databricks-sql-connector`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
catalog (str): The catalog name in the Databricks database.
|
||||||
|
schema (str): The schema name in the catalog.
|
||||||
|
host (Optional[str]): The Databricks workspace hostname, excluding
|
||||||
|
'https://' part. If not provided, it attempts to fetch from the
|
||||||
|
environment variable 'DATABRICKS_HOST'. If still unavailable and if
|
||||||
|
running in a Databricks notebook, it defaults to the current workspace
|
||||||
|
hostname. Defaults to None.
|
||||||
|
api_token (Optional[str]): The Databricks personal access token for
|
||||||
|
accessing the Databricks SQL warehouse or the cluster. If not provided,
|
||||||
|
it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
|
||||||
|
and running in a Databricks notebook, a temporary token for the current
|
||||||
|
user is generated. Defaults to None.
|
||||||
|
warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
|
||||||
|
provided, the method configures the connection to use this warehouse.
|
||||||
|
Cannot be used with 'cluster_id'. Defaults to None.
|
||||||
|
cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
|
||||||
|
provided, the method configures the connection to use this cluster.
|
||||||
|
Cannot be used with 'warehouse_id'. If running in a Databricks notebook
|
||||||
|
and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
|
||||||
|
cluster the notebook is attached to. Defaults to None.
|
||||||
|
engine_args (Optional[dict]): The arguments to be used when connecting
|
||||||
|
Databricks. Defaults to None.
|
||||||
|
**kwargs (Any): Additional keyword arguments for the `from_uri` method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLDatabase: An instance of SQLDatabase configured with the provided
|
||||||
|
Databricks connection details.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If 'databricks-sql-connector' is not found, or if both
|
||||||
|
'warehouse_id' and 'cluster_id' are provided, or if neither
|
||||||
|
'warehouse_id' nor 'cluster_id' are provided and it's not executing
|
||||||
|
inside a Databricks notebook.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from databricks import sql # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"databricks-sql-connector package not found, please install with"
|
||||||
|
" `pip install databricks-sql-connector`"
|
||||||
|
)
|
||||||
|
context = None
|
||||||
|
try:
|
||||||
|
from dbruntime.databricks_repl_context import get_context
|
||||||
|
|
||||||
|
context = get_context()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
default_host = context.browserHostName if context else None
|
||||||
|
if host is None:
|
||||||
|
host = get_from_env("host", "DATABRICKS_HOST", default_host)
|
||||||
|
|
||||||
|
default_api_token = context.apiToken if context else None
|
||||||
|
if api_token is None:
|
||||||
|
api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)
|
||||||
|
|
||||||
|
if warehouse_id is None and cluster_id is None:
|
||||||
|
if context:
|
||||||
|
cluster_id = context.clusterId
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Need to provide either 'warehouse_id' or 'cluster_id'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if warehouse_id and cluster_id:
|
||||||
|
raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")
|
||||||
|
|
||||||
|
if warehouse_id:
|
||||||
|
http_path = f"/sql/1.0/warehouses/{warehouse_id}"
|
||||||
|
else:
|
||||||
|
http_path = f"/sql/protocolv1/o/0/{cluster_id}"
|
||||||
|
|
||||||
|
uri = (
|
||||||
|
f"databricks://token:{api_token}@{host}?"
|
||||||
|
f"http_path={http_path}&catalog={catalog}&schema={schema}"
|
||||||
|
)
|
||||||
|
return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
def from_cnosdb(
    cls,
    url: str = "127.0.0.1:8902",
    user: str = "root",
    password: str = "",
    tenant: str = "cnosdb",
    database: str = "public",
) -> SQLDatabase:
    """
    Class method to create an SQLDatabase instance from a CnosDB connection.
    This method requires the 'cnos-connector' package. If not installed, it
    can be added using `pip install cnos-connector`.

    Args:
        url (str): The HTTP connection host name and port number of the CnosDB
            service, excluding "http://" or "https://", with a default value
            of "127.0.0.1:8902".
        user (str): The username used to connect to the CnosDB service, with a
            default value of "root".
        password (str): The password of the user connecting to the CnosDB service,
            with a default value of "".
        tenant (str): The name of the tenant used to connect to the CnosDB service,
            with a default value of "cnosdb".
        database (str): The name of the database in the CnosDB tenant, with a
            default value of "public".

    Returns:
        SQLDatabase: An instance of SQLDatabase configured with the provided
        CnosDB connection details.

    Raises:
        ValueError: If the `cnosdb_connector` module cannot be imported.
    """
    # Guard only the import itself: an ImportError raised while building the
    # URI or inside `from_uri` must not be misreported as a missing
    # cnos-connector package.
    try:
        from cnosdb_connector import make_cnosdb_langchain_uri
    except ImportError as exc:
        raise ValueError(
            "cnos-connector package not found, please install with"
            " `pip install cnos-connector`"
        ) from exc

    uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
    return cls.from_uri(database_uri=uri)
|
||||||
|
|
||||||
|
@property
def dialect(self) -> str:
    """Name of the dialect reported by the underlying engine."""
    engine_dialect = self._engine.dialect
    return engine_dialect.name
|
||||||
|
|
||||||
|
def get_usable_table_names(self) -> Iterable[str]:
    """Return the sorted names of the tables that may be used.

    A non-empty include list wins outright; otherwise the result is every
    known table minus the ignored ones.
    """
    included = self._include_tables
    usable = included if included else self._all_tables - self._ignore_tables
    return sorted(usable)
|
||||||
|
|
||||||
|
def get_table_names(self) -> Iterable[str]:
    """Deprecated alias of `get_usable_table_names`.

    Emits a deprecation warning and then returns exactly what
    `get_usable_table_names` returns.
    """
    warnings.warn(
        "This method is deprecated - please use `get_usable_table_names`.",
        # Attribute the warning to the caller's line instead of this wrapper
        # so users can find the call they need to migrate.
        stacklevel=2,
    )
    return self.get_usable_table_names()
|
||||||
|
|
||||||
|
@property
def table_info(self) -> str:
    """Description of the tables, as produced by `get_table_info()` with no arguments."""
    description = self.get_table_info()
    return description
|
||||||
|
|
||||||
|
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
    """Get information about specified tables.

    Follows best practices as specified in: Rajkumar et al, 2022
    (https://arxiv.org/abs/2204.00498)

    If `sample_rows_in_table_info`, the specified number of sample rows will be
    appended to each table description. This can increase performance as
    demonstrated in the paper.

    Args:
        table_names: Names of the tables to describe. When None, every name
            returned by `get_usable_table_names` is described.

    Returns:
        The per-table descriptions, sorted and joined by blank lines.

    Raises:
        ValueError: If any requested name is not among the usable table names.
    """
    all_table_names = self.get_usable_table_names()
    if table_names is not None:
        missing_tables = set(table_names).difference(all_table_names)
        if missing_tables:
            raise ValueError(f"table_names {missing_tables} not found in database")
        all_table_names = table_names

    # Build the membership set once instead of once per reflected table.
    wanted_names = set(all_table_names)
    meta_tables = [
        tbl
        for tbl in self._metadata.sorted_tables
        if tbl.name in wanted_names
        and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
    ]

    tables = []
    for table in meta_tables:
        # A user-supplied description replaces the generated one entirely.
        if self._custom_table_info and table.name in self._custom_table_info:
            tables.append(self._custom_table_info[table.name])
            continue

        # add create table command
        create_table = str(CreateTable(table).compile(self._engine))
        table_info = create_table.rstrip()

        # Optional extras go into a single SQL block comment after the DDL.
        if self._indexes_in_table_info or self._sample_rows_in_table_info:
            table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            table_info += "*/"
        tables.append(table_info)

    tables.sort()
    return "\n\n".join(tables)
|
||||||
|
|
||||||
|
def _get_table_indexes(self, table: Table) -> str:
    """Render the indexes the inspector reports for `table` as a text section.

    Each index is formatted with `_format_index`, one per line, under a
    "Table Indexes:" heading.
    """
    formatted_lines = [
        _format_index(index) for index in self._inspector.get_indexes(table.name)
    ]
    indexes_formatted = "\n".join(formatted_lines)
    return f"Table Indexes:\n{indexes_formatted}"
|
||||||
|
|
||||||
|
def _get_sample_rows(self, table: Table) -> str:
    """Render a heading, the column names and a few sample rows of `table`.

    At most `_sample_rows_in_table_info` rows are selected. Column names and
    row values are tab-separated; every value is converted with `str` and cut
    to its first 100 characters. If the query raises `ProgrammingError` the
    rows section is left empty.
    """
    row_limit = self._sample_rows_in_table_info

    # build the select command
    query = select(table).limit(row_limit)

    # column names as a single tab-separated heading line
    header = "\t".join(column.name for column in table.columns)

    try:
        with self._engine.connect() as connection:
            fetched = connection.execute(query)  # type: ignore
            # stringify and shorten the values while the connection is open
            clipped_rows = [[str(value)[:100] for value in row] for row in fetched]

        rows_text = "\n".join("\t".join(row) for row in clipped_rows)

    # in some dialects when there are no rows in the table a
    # 'ProgrammingError' is returned
    except ProgrammingError:
        rows_text = ""

    return (
        f"{row_limit} rows from {table.name} table:\n"
        f"{header}\n"
        f"{rows_text}"
    )
|
||||||
|
|
||||||
|
def run(self, command: str, fetch: str = "all") -> str:
    """Execute a SQL command and return a string representing the results.

    If the statement returns rows, a string of the results is returned.
    If the statement returns no rows, an empty string is returned.

    Args:
        command: The SQL text to execute.
        fetch: "all" to return every row, or "one" to return only the first.

    Returns:
        For "all", the string form of a list of row tuples; for "one", the
        string form of a single row tuple, or "" when the result set is
        empty. "" is also returned when the statement returns no rows.

    Raises:
        ValueError: If the statement returns rows and `fetch` is neither
            "one" nor "all".
    """
    with self._engine.begin() as connection:
        if self._schema is not None:
            # NOTE(review): the schema name is interpolated directly into the
            # statement; presumably it is trusted configuration - confirm.
            if self.dialect == "snowflake":
                connection.exec_driver_sql(
                    f"ALTER SESSION SET search_path='{self._schema}'"
                )
            elif self.dialect == "bigquery":
                connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
            else:
                connection.exec_driver_sql(f"SET search_path TO {self._schema}")
        cursor = connection.execute(text(command))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
            elif fetch == "one":
                result = cursor.fetchone()  # type: ignore
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")

            # fetchone() gives None for an empty result set; iterating it
            # below would raise TypeError, so report "no rows" instead.
            if result is None:
                return ""

            # Convert column values to string to avoid issues with sqlalchemy
            # truncating text
            if isinstance(result, list):
                return str(
                    [
                        tuple(
                            truncate_word(c, length=self._max_string_length)
                            for c in r
                        )
                        for r in result
                    ]
                )

            return str(
                tuple(
                    truncate_word(c, length=self._max_string_length) for c in result
                )
            )
    return ""
|
||||||
|
|
||||||
|
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
    """Get information about specified tables, reporting ValueError as text.

    Delegates to `get_table_info` with the same `table_names`. If that call
    raises ValueError, the exception text is returned prefixed with
    "Error: " instead of being propagated.
    """
    try:
        info = self.get_table_info(table_names)
    except ValueError as e:
        # Format the error message
        return f"Error: {e}"
    return info
|
||||||
|
|
||||||
|
def run_no_throw(self, command: str, fetch: str = "all") -> str:
    """Execute a SQL command via `run`, reporting SQLAlchemy errors as text.

    The return value of `run(command, fetch)` is passed through unchanged.
    If `run` raises `SQLAlchemyError`, the exception text is returned
    prefixed with "Error: " instead of being propagated.
    """
    try:
        outcome = self.run(command, fetch)
    except SQLAlchemyError as e:
        # Format the error message
        return f"Error: {e}"
    return outcome
|
@ -6,7 +6,7 @@ from langchain.chains.sql_database.base import (
|
|||||||
SQLDatabaseSequentialChain,
|
SQLDatabaseSequentialChain,
|
||||||
)
|
)
|
||||||
from langchain.llms.openai import OpenAI
|
from langchain.llms.openai import OpenAI
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from langchain.agents import create_sql_agent
|
from langchain.agents import create_sql_agent
|
||||||
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from sqlalchemy import (
|
|||||||
insert,
|
insert,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.sql_database import SQLDatabase, truncate_word
|
from langchain.utilities.sql_database import SQLDatabase, truncate_word
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from sqlalchemy import (
|
|||||||
schema,
|
schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user