feat(core): Support simple DB query for sdk (#917)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng 2023-12-11 18:33:54 +08:00 committed by GitHub
parent 43190ca333
commit cbba50ab1b
18 changed files with 467 additions and 74 deletions

View File

@@ -266,19 +266,3 @@ class Config(metaclass=Singleton):
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
def set_templature(self, value: int) -> None:
"""Set the temperature value."""
self.temperature = value
def set_speak_mode(self, value: bool) -> None:
"""Set the speak mode value."""
self.speak_mode = value
def set_last_plugin_return(self, value: bool) -> None:
"""Set the speak mode value."""
self.last_plugin_return = value

View File

@@ -5,13 +5,13 @@ from typing import List
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
LLM_MODEL_CONFIG,
EMBEDDING_MODEL_CONFIG,
LOGDIR,
ROOT_PATH,
)
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.app.base import (
@@ -30,7 +30,6 @@ from dbgpt.app.knowledge.api import router as knowledge_router
from dbgpt.app.prompt.api import router as prompt_router
from dbgpt.app.llm_manage.api import router as llm_manage_api
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
from dbgpt.app.openapi.base import validation_exception_handler
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
@@ -59,7 +58,7 @@ def swagger_monkey_patch(*args, **kwargs):
*args,
**kwargs,
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
)
@@ -79,13 +78,11 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
@@ -133,7 +130,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
# Before start
system_app.before_start()
model_name = param.model_name or CFG.LLM_MODEL
param.model_name = model_name
print(param)
embedding_model_name = CFG.EMBEDDING_MODEL
@@ -143,8 +141,6 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_name = param.model_name or CFG.LLM_MODEL
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")

View File

@@ -52,6 +52,7 @@ class ChatWithDbAutoExecute(BaseChat):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
table_infos = await blocking_func_to_async(

View File

@@ -11,7 +11,7 @@ from dbgpt.core.interface.message import (
OnceConversation,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.cache import (
CacheKey,
@@ -33,6 +33,7 @@ __ALL__ = [
"PromptTemplate",
"PromptTemplateOperator",
"BaseOutputParser",
"SQLOutputParser",
"Serializable",
"Serializer",
"CacheKey",

View File

@@ -53,7 +53,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
@property
def is_empty(self) -> bool:
return not self._data
return self._data is None
async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
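Note the semantic shift in is_empty: the old check "not self._data" treated any falsy output (empty string, 0, empty list) as empty, whereas "self._data is None" only reports empty when no output was produced at all. A plain-Python sketch of the difference (illustrative, not part of this commit):

def is_empty_old(data):
    return not data

def is_empty_new(data):
    return data is None

# An empty string is a real (if falsy) output: only the old check calls it empty.
assert is_empty_old("") is True
assert is_empty_new("") is False
# Both agree when nothing was produced at all.
assert is_empty_old(None) is True
assert is_empty_new(None) is True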

View File

@@ -251,3 +251,13 @@ def _parse_model_response(response: ResponseTye):
else:
raise ValueError(f"Unsupported response type {type(response)}")
return resp_obj_ex
class SQLOutputParser(BaseOutputParser):
def __init__(self, is_stream_out: bool = False, **kwargs):
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
model_out_text = super().parse_model_nostream_resp(response, sep)
clean_str = super().parse_prompt_response(model_out_text)
return json.loads(clean_str, strict=True)
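SQLOutputParser reuses the base parser's response handling and prompt cleaning, then decodes the cleaned text as JSON. The SQL example DAG added later in this commit asks the model for a {"thoughts": ..., "sql": ...} object, so the parsed value is a dict. A rough sketch of the final decoding step, using a hypothetical cleaned model response:

import json

# Hypothetical cleaned model output, in the response format used by the
# SQL example DAG in this commit.
clean_str = (
    '{"thoughts": "Query users younger than 18", '
    '"sql": "SELECT name, age FROM user WHERE age < 18 LIMIT 5"}'
)
parsed = json.loads(clean_str, strict=True)
assert parsed["sql"].startswith("SELECT")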

View File

@@ -0,0 +1,23 @@
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
class RetrieverOperator(MapOperator[IN, OUT]):
"""The Abstract Retriever Operator."""
async def map(self, input_value: IN) -> OUT:
"""Map input value to output value.
Args:
input_value (IN): The input value.
Returns:
OUT: The output value.
"""
# The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async.
return await self.blocking_func_to_async(self.retrieve, input_value)
@abstractmethod
def retrieve(self, input_value: IN) -> OUT:
"""Retrieve data for input value."""

View File

@@ -2,58 +2,104 @@
# -*- coding:utf-8 -*-
"""We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional
from abc import ABC
from typing import Iterable, List, Optional
class BaseConnect(ABC):
def get_connect(self, db_name: str):
pass
def get_table_names(self) -> Iterable[str]:
"""Get all table names"""
pass
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get table info about specified table.
Returns:
str: Table information joined by '\n\n'
"""
pass
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get index info about specified table.
Args:
table_names (Optional[List[str]]): table names
"""
pass
def get_example_data(self, table: str, count: int = 3):
"""Get example data about specified table.
Not used now.
Args:
table (str): table name
count (int): example data count
"""
pass
def get_database_list(self):
def get_database_list(self) -> List[str]:
"""Get database list.
Returns:
List[str]: database list
"""
pass
def get_database_names(self):
"""Get database names."""
pass
def get_table_comments(self, db_name):
"""Get table comments.
Args:
db_name (str): database name
"""
pass
def run(self, session, command: str, fetch: str = "all") -> List:
def run(self, command: str, fetch: str = "all") -> List:
"""Execute sql command.
Args:
command (str): sql command
fetch (str): fetch type
"""
pass
def run_to_df(self, command: str, fetch: str = "all"):
"""Execute sql command and return dataframe."""
pass
def get_users(self):
pass
"""Get user info."""
return []
def get_grants(self):
pass
"""Get grant info."""
return []
def get_collation(self):
pass
"""Get collation."""
return None
def get_charset(self):
pass
def get_charset(self) -> str:
"""Get character_set of current database."""
return "utf-8"
def get_fields(self, table_name):
"""Get column fields about specified table."""
pass
def get_show_create_table(self, table_name):
"""Get the creation table sql about specified table."""
pass
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
pass
@classmethod
def is_normal_type(cls) -> bool:
"""Return whether the connector is a normal type."""
return True

View File

@@ -1,4 +1,4 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel
class DBConfig(BaseModel):

View File

@@ -1,3 +1,4 @@
from typing import List, Type
from dbgpt.datasource import ConnectConfigDao
from dbgpt.storage.schema import DBType
from dbgpt.component import SystemApp, ComponentType
@@ -21,7 +22,7 @@ from dbgpt.datasource.rdbms.conn_doris import DorisConnect
class ConnectManager:
"""db connect manager"""
def get_all_subclasses(self, cls):
def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += self.get_all_subclasses(subclass)
@@ -31,7 +32,7 @@ class ConnectManager:
chat_classes = self.get_all_subclasses(BaseConnect)
support_types = []
for cls in chat_classes:
if cls.db_type:
if cls.db_type and cls.is_normal_type():
support_types.append(DBType.of_db_type(cls.db_type))
return support_types
@@ -39,7 +40,7 @@ class ConnectManager:
chat_classes = self.get_all_subclasses(BaseConnect)
result = None
for cls in chat_classes:
if cls.db_type == db_type:
if cls.db_type == db_type and cls.is_normal_type():
result = cls
if not result:
raise ValueError("Unsupported Db Type" + db_type)

View File

View File

@@ -0,0 +1,16 @@
from typing import Any
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
from dbgpt.datasource.base import BaseConnect
class DatasourceOperator(MapOperator[str, Any]):
def __init__(self, connection: BaseConnect, **kwargs):
super().__init__(**kwargs)
self._connection = connection
async def map(self, input_value: IN) -> OUT:
return await self.blocking_func_to_async(self.query, input_value)
def query(self, input_value: str) -> Any:
return self._connection.run_to_df(input_value)
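A sketch of running the operator on its own, modeled on the SQL example added later in this commit; it assumes the SQLiteTempConnect helper from this change and the AWEL call_data convention shown there:

import asyncio

from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect

connection = SQLiteTempConnect.create_temporary_db()
connection.create_temp_tables(
    {
        "user": {
            "columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT"},
            "data": [(1, "Tom"), (2, "Jerry")],
        }
    }
)

with DAG("datasource_query_example") as dag:
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    query_task = DatasourceOperator(connection=connection)
    input_task >> query_task

# The SQL string flows through InputOperator into DatasourceOperator.query,
# which returns a DataFrame via run_to_df.
df = asyncio.run(query_task.call(call_data={"data": "SELECT * FROM user"}))
print(df)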

View File

@@ -4,9 +4,12 @@
import os
from typing import Optional, Any, Iterable
from sqlalchemy import create_engine, text
import tempfile
import logging
from dbgpt.datasource.rdbms.base import RDBMSDatabase
logger = logging.getLogger(__name__)
class SQLiteConnect(RDBMSDatabase):
"""Connect SQLite Database fetch MetaData
@@ -127,3 +130,116 @@ class SQLiteConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});")
return results
class SQLiteTempConnect(SQLiteConnect):
"""A temporary SQLite database connection. The database file will be deleted when the connection is closed."""
def __init__(self, engine, temp_file_path, *args, **kwargs):
super().__init__(engine, *args, **kwargs)
self.temp_file_path = temp_file_path
self._is_closed = False
@classmethod
def create_temporary_db(
cls, engine_args: Optional[dict] = None, **kwargs: Any
) -> "SQLiteTempConnect":
"""Create a temporary SQLite database with a temporary file.
Examples:
.. code-block:: python
with SQLiteTempConnect.create_temporary_db() as db:
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")
db.run(db.session, "insert into test(id) values (2)")
field_names, result = db.query_ex(db.session, "select * from test")
assert field_names == ["id"]
assert result == [(1,), (2,)]
Args:
engine_args (Optional[dict]): SQLAlchemy engine arguments.
Returns:
SQLiteTempConnect: A SQLiteTempConnect instance.
"""
_engine_args = engine_args or {}
_engine_args["connect_args"] = {"check_same_thread": False}
temp_file = tempfile.NamedTemporaryFile(delete=False)
temp_file_path = temp_file.name
temp_file.close()
engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
return cls(engine, temp_file_path, **kwargs)
def close(self):
"""Close the connection."""
if not self._is_closed:
if self._engine:
self._engine.dispose()
try:
if os.path.exists(self.temp_file_path):
os.remove(self.temp_file_path)
except Exception as e:
logger.error(f"Error removing temporary database file: {e}")
self._is_closed = True
def create_temp_tables(self, tables_info):
"""Create temporary tables with data.
Examples:
.. code-block:: python
tables_info = {
"test": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 20),
(2, "Jack", 21),
(3, "Alice", 22),
],
},
}
with SQLiteTempConnect.create_temporary_db() as db:
db.create_temp_tables(tables_info)
field_names, result = db.query_ex(db.session, "select * from test")
assert field_names == ["id", "name", "age"]
assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]
Args:
tables_info (dict): A dictionary of table information.
"""
for table_name, table_data in tables_info.items():
columns = ", ".join(
[f"{col} {dtype}" for col, dtype in table_data["columns"].items()]
)
create_sql = f"CREATE TABLE {table_name} ({columns});"
self.session.execute(text(create_sql))
for row in table_data.get("data", []):
placeholders = ", ".join(
[":param" + str(index) for index, _ in enumerate(row)]
)
insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
param_dict = {
"param" + str(index): value for index, value in enumerate(row)
}
self.session.execute(text(insert_sql), param_dict)
self.session.commit()
self._sync_tables_from_db()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __del__(self):
self.close()
@classmethod
def is_normal_type(cls) -> bool:
return False

View File

View File

@@ -0,0 +1,14 @@
from typing import Any
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
def __init__(self, connection: RDBMSDatabase, **kwargs):
super().__init__(**kwargs)
self._connection = connection
def retrieve(self, input_value: Any) -> Any:
summary = _parse_db_summary(self._connection)
return summary
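retrieve() ignores its input and simply returns the schema summary strings produced by _parse_db_summary (defined in the summary module changes below); in the example DAG these strings become the {table_info} section of the prompt. A quick sketch of what the underlying call yields, assuming the same temporary user table used by the examples in this commit:

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

connection = SQLiteTempConnect.create_temporary_db()
connection.create_temp_tables(
    {
        "user": {
            "columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT", "age": "INTEGER"},
        }
    }
)

# With the default "{table_name}({columns})" template this yields one entry per
# table, roughly of the form "user(id, name, age)".
for table_summary in _parse_db_summary(connection):
    print(table_summary)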

View File

@@ -1,5 +1,7 @@
from typing import List
from dbgpt._private.config import Config
from dbgpt.rag.summary.db_summary import DBSummary
from dbgpt.datasource.rdbms.base import RDBMSDatabase
CFG = Config()
@@ -36,32 +38,63 @@ class RdbmsSummary(DBSummary):
example:
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
"""
columns = []
for column in self.db._inspector.get_columns(table_name):
if column.get("comment"):
columns.append((f"{column['name']} ({column.get('comment')})"))
else:
columns.append(f"{column['name']}")
column_str = ", ".join(columns)
index_keys = []
for index_key in self.db._inspector.get_indexes(table_name):
key_str = ", ".join(index_key["column_names"])
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
table_str = self.summary_template.format(
table_name=table_name, columns=column_str
)
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f", and index keys: {index_key_str}"
try:
comment = self.db._inspector.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += f", and table comment: {comment.get('text')}"
return table_str
return _parse_table_summary(self.db, self.summary_template, table_name)
def table_summaries(self):
"""Get table summaries."""
return self.table_info_summaries
def _parse_db_summary(
conn: RDBMSDatabase, summary_template: str = "{table_name}({columns})"
) -> List[str]:
"""Get db summary for database.
Args:
conn (RDBMSDatabase): database connection
summary_template (str): summary template
"""
tables = conn.get_table_names()
table_info_summaries = [
_parse_table_summary(conn, summary_template, table_name)
for table_name in tables
]
return table_info_summaries
def _parse_table_summary(
conn: RDBMSDatabase, summary_template: str, table_name: str
) -> str:
"""Get table summary for table.
Args:
conn (RDBMSDatabase): database connection
summary_template (str): summary template
table_name (str): table name
Examples:
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
"""
columns = []
for column in conn._inspector.get_columns(table_name):
if column.get("comment"):
columns.append(f"{column['name']} ({column.get('comment')})")
else:
columns.append(f"{column['name']}")
column_str = ", ".join(columns)
index_keys = []
for index_key in conn._inspector.get_indexes(table_name):
key_str = ", ".join(index_key["column_names"])
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
table_str = summary_template.format(table_name=table_name, columns=column_str)
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f", and index keys: {index_key_str}"
try:
comment = conn._inspector.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += f", and table comment: {comment.get('text')}"
return table_str

View File

@@ -3,16 +3,18 @@ from dbgpt.core.awel import DAG
from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
with DAG("simple_sdk_llm_example_dag") as dag:
prompt = PromptTemplate.from_template(
prompt_task = PromptTemplate.from_template(
"Write a SQL of {dialect} to query all data of {table_name}."
)
req_builder = RequestBuildOperator(model="gpt-3.5-turbo")
llm = OpenAILLM()
out_parser = BaseOutputParser()
prompt >> req_builder >> llm >> out_parser
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
out_parse_task = BaseOutputParser()
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
if __name__ == "__main__":
output = asyncio.run(
out_parser.call(call_data={"data": {"dialect": "mysql", "table_name": "user"}})
out_parse_task.call(
call_data={"data": {"dialect": "mysql", "table_name": "user"}}
)
)
print(f"output: \n\n{output}")

View File

@@ -0,0 +1,150 @@
import asyncio
from typing import Dict, List
import json
from dbgpt.core.awel import (
DAG,
InputOperator,
SimpleCallDataInputSource,
JoinOperator,
MapOperator,
)
from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator
def _create_temporary_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
return connect
def _sql_prompt() -> str:
"""This is a prompt template for SQL generation.
Format of arguments:
{db_name}: database name
{table_info}: table structure information
{dialect}: database dialect
{top_k}: maximum number of results
{user_input}: user question
{response}: response format
Returns:
str: prompt template
"""
return """Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}
Constraint:
1.Please understand the user's intention based on the user's question, and use the given table structure definition to create a grammatically correct {dialect} sql. If sql is not required, answer the user's question directly..
2.Always limit the query to a maximum of {top_k} results unless the user specifies in the question the specific number of rows of data he wishes to obtain.
3.You can only use the tables provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql queries." It is prohibited to fabricate information at will.
4.Please be careful not to mistake the relationship between tables and columns when generating SQL.
5.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
User Question:
{user_input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""
def _join_func(query_dict: Dict, db_summary: List[str]):
"""Join function for JoinOperator.
Build the format arguments for the prompt template.
Args:
query_dict (Dict): The query dict from DAG input.
db_summary (List[str]): The table structure information from DatasourceRetrieverOperator.
Returns:
Dict: The query dict with the format arguments.
"""
default_response = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
}
response = json.dumps(default_response, ensure_ascii=False, indent=4)
query_dict["table_info"] = db_summary
query_dict["response"] = response
return query_dict
class SQLResultOperator(JoinOperator[Dict]):
"""Merge the SQL result and the model result."""
def __init__(self, **kwargs):
super().__init__(combine_function=self._combine_result, **kwargs)
def _combine_result(self, sql_result_df, model_result: Dict) -> Dict:
model_result["data_df"] = sql_result_df
return model_result
with DAG("simple_sdk_llm_sql_example") as dag:
db_connection = _create_temporary_connection()
input_task = InputOperator(input_source=SimpleCallDataInputSource())
retriever_task = DatasourceRetrieverOperator(connection=db_connection)
# Merge the input data and the table structure information.
prompt_input_task = JoinOperator(combine_function=_join_func)
prompt_task = PromptTemplate.from_template(_sql_prompt())
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
out_parse_task = SQLOutputParser()
sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
db_query_task = DatasourceOperator(connection=db_connection)
sql_result_task = SQLResultOperator()
input_task >> prompt_input_task
input_task >> retriever_task >> prompt_input_task
(
prompt_input_task
>> prompt_task
>> model_pre_handle_task
>> llm_task
>> out_parse_task
>> sql_parse_task
>> db_query_task
>> sql_result_task
)
out_parse_task >> sql_result_task
if __name__ == "__main__":
input_data = {
"data": {
"db_name": "test_db",
"dialect": "sqlite",
"top_k": 5,
"user_input": "What is the name and age of the user with age less than 18",
}
}
output = asyncio.run(sql_result_task.call(call_data=input_data))
print(f"\nthoughts: {output.get('thoughts')}\n")
print(f"sql: {output.get('sql')}\n")
print(f"result data:\n{output.get('data_df')}")