mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
feat(core): Support simple DB query for sdk (#917)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
parent
43190ca333
commit
cbba50ab1b
@ -266,19 +266,3 @@ class Config(metaclass=Singleton):
|
||||
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
|
||||
"MODEL_CACHE_STORAGE_DISK_DIR"
|
||||
)
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
self.debug_mode = value
|
||||
|
||||
def set_templature(self, value: int) -> None:
|
||||
"""Set the temperature value."""
|
||||
self.temperature = value
|
||||
|
||||
def set_speak_mode(self, value: bool) -> None:
|
||||
"""Set the speak mode value."""
|
||||
self.speak_mode = value
|
||||
|
||||
def set_last_plugin_return(self, value: bool) -> None:
|
||||
"""Set the speak mode value."""
|
||||
self.last_plugin_return = value
|
||||
|
@ -5,13 +5,13 @@ from typing import List
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import (
|
||||
LLM_MODEL_CONFIG,
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
ROOT_PATH,
|
||||
)
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from dbgpt.app.base import (
|
||||
@ -30,7 +30,6 @@ from dbgpt.app.knowledge.api import router as knowledge_router
|
||||
from dbgpt.app.prompt.api import router as prompt_router
|
||||
from dbgpt.app.llm_manage.api import router as llm_manage_api
|
||||
|
||||
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
|
||||
from dbgpt.app.openapi.base import validation_exception_handler
|
||||
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
@ -59,7 +58,7 @@ def swagger_monkey_patch(*args, **kwargs):
|
||||
*args,
|
||||
**kwargs,
|
||||
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
|
||||
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
|
||||
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
|
||||
)
|
||||
|
||||
|
||||
@ -79,13 +78,11 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
app.include_router(api_v1, prefix="/api", tags=["Chat"])
|
||||
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
|
||||
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
|
||||
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
|
||||
|
||||
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
app.include_router(prompt_router, tags=["Prompt"])
|
||||
|
||||
@ -133,7 +130,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
|
||||
# Before start
|
||||
system_app.before_start()
|
||||
|
||||
model_name = param.model_name or CFG.LLM_MODEL
|
||||
param.model_name = model_name
|
||||
print(param)
|
||||
|
||||
embedding_model_name = CFG.EMBEDDING_MODEL
|
||||
@ -143,8 +141,6 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
|
||||
model_name = param.model_name or CFG.LLM_MODEL
|
||||
|
||||
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
|
@ -52,6 +52,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||
table_infos = None
|
||||
try:
|
||||
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
|
||||
table_infos = await blocking_func_to_async(
|
||||
|
@ -11,7 +11,7 @@ from dbgpt.core.interface.message import (
|
||||
OnceConversation,
|
||||
)
|
||||
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.core.interface.cache import (
|
||||
CacheKey,
|
||||
@ -33,6 +33,7 @@ __ALL__ = [
|
||||
"PromptTemplate",
|
||||
"PromptTemplateOperator",
|
||||
"BaseOutputParser",
|
||||
"SQLOutputParser",
|
||||
"Serializable",
|
||||
"Serializer",
|
||||
"CacheKey",
|
||||
|
@ -53,7 +53,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
return self._data is None
|
||||
|
||||
async def _apply_func(self, func) -> Any:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
@ -251,3 +251,13 @@ def _parse_model_response(response: ResponseTye):
|
||||
else:
|
||||
raise ValueError(f"Unsupported response type {type(response)}")
|
||||
return resp_obj_ex
|
||||
|
||||
|
||||
class SQLOutputParser(BaseOutputParser):
|
||||
def __init__(self, is_stream_out: bool = False, **kwargs):
|
||||
super().__init__(is_stream_out=is_stream_out, **kwargs)
|
||||
|
||||
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
|
||||
model_out_text = super().parse_model_nostream_resp(response, sep)
|
||||
clean_str = super().parse_prompt_response(model_out_text)
|
||||
return json.loads(clean_str, strict=True)
|
||||
|
23
dbgpt/core/interface/retriever.py
Normal file
23
dbgpt/core/interface/retriever.py
Normal file
@ -0,0 +1,23 @@
|
||||
from abc import abstractmethod
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
|
||||
|
||||
class RetrieverOperator(MapOperator[IN, OUT]):
|
||||
"""The Abstract Retriever Operator."""
|
||||
|
||||
async def map(self, input_value: IN) -> OUT:
|
||||
"""Map input value to output value.
|
||||
|
||||
Args:
|
||||
input_value (IN): The input value.
|
||||
|
||||
Returns:
|
||||
OUT: The output value.
|
||||
"""
|
||||
# The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async.
|
||||
return await self.blocking_func_to_async(self.retrieve, input_value)
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, input_value: IN) -> OUT:
|
||||
"""Retrieve data for input value."""
|
@ -2,58 +2,104 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
"""We need to design a base class. That other connector can Write with this"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from abc import ABC
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
|
||||
class BaseConnect(ABC):
|
||||
def get_connect(self, db_name: str):
|
||||
pass
|
||||
|
||||
def get_table_names(self) -> Iterable[str]:
|
||||
"""Get all table names"""
|
||||
pass
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get table info about specified table.
|
||||
|
||||
Returns:
|
||||
str: Table information joined by '\n\n'
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get index info about specified table.
|
||||
|
||||
Args:
|
||||
table_names (Optional[List[str]]): table names
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_example_data(self, table: str, count: int = 3):
|
||||
"""Get example data about specified table.
|
||||
|
||||
Not used now.
|
||||
|
||||
Args:
|
||||
table (str): table name
|
||||
count (int): example data count
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_database_list(self):
|
||||
def get_database_list(self) -> List[str]:
|
||||
"""Get database list.
|
||||
|
||||
Returns:
|
||||
List[str]: database list
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
pass
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
"""Get table comments.
|
||||
|
||||
Args:
|
||||
db_name (str): database name
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute sql command.
|
||||
|
||||
Args:
|
||||
command (str): sql command
|
||||
fetch (str): fetch type
|
||||
"""
|
||||
pass
|
||||
|
||||
def run_to_df(self, command: str, fetch: str = "all"):
|
||||
"""Execute sql command and return dataframe."""
|
||||
pass
|
||||
|
||||
def get_users(self):
|
||||
pass
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
pass
|
||||
"""Get grant info."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
pass
|
||||
"""Get collation."""
|
||||
return None
|
||||
|
||||
def get_charset(self):
|
||||
pass
|
||||
def get_charset(self) -> str:
|
||||
"""Get character_set of current database."""
|
||||
return "utf-8"
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
pass
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get the creation table sql about specified table."""
|
||||
pass
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_normal_type(cls) -> bool:
|
||||
"""Return whether the connector is a normal type."""
|
||||
return True
|
||||
|
@ -1,4 +1,4 @@
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class DBConfig(BaseModel):
|
||||
|
@ -1,3 +1,4 @@
|
||||
from typing import List, Type
|
||||
from dbgpt.datasource import ConnectConfigDao
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.component import SystemApp, ComponentType
|
||||
@ -21,7 +22,7 @@ from dbgpt.datasource.rdbms.conn_doris import DorisConnect
|
||||
class ConnectManager:
|
||||
"""db connect manager"""
|
||||
|
||||
def get_all_subclasses(self, cls):
|
||||
def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]:
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += self.get_all_subclasses(subclass)
|
||||
@ -31,7 +32,7 @@ class ConnectManager:
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
support_types = []
|
||||
for cls in chat_classes:
|
||||
if cls.db_type:
|
||||
if cls.db_type and cls.is_normal_type():
|
||||
support_types.append(DBType.of_db_type(cls.db_type))
|
||||
return support_types
|
||||
|
||||
@ -39,7 +40,7 @@ class ConnectManager:
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
result = None
|
||||
for cls in chat_classes:
|
||||
if cls.db_type == db_type:
|
||||
if cls.db_type == db_type and cls.is_normal_type():
|
||||
result = cls
|
||||
if not result:
|
||||
raise ValueError("Unsupported Db Type!" + db_type)
|
||||
|
0
dbgpt/datasource/operator/__init__.py
Normal file
0
dbgpt/datasource/operator/__init__.py
Normal file
16
dbgpt/datasource/operator/datasource_operator.py
Normal file
16
dbgpt/datasource/operator/datasource_operator.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
|
||||
|
||||
class DatasourceOperator(MapOperator[str, Any]):
|
||||
def __init__(self, connection: BaseConnect, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
|
||||
async def map(self, input_value: IN) -> OUT:
|
||||
return await self.blocking_func_to_async(self.query, input_value)
|
||||
|
||||
def query(self, input_value: str) -> Any:
|
||||
return self._connection.run_to_df(input_value)
|
@ -4,9 +4,12 @@
|
||||
import os
|
||||
from typing import Optional, Any, Iterable
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
import tempfile
|
||||
import logging
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLiteConnect(RDBMSDatabase):
|
||||
"""Connect SQLite Database fetch MetaData
|
||||
@ -127,3 +130,116 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
|
||||
class SQLiteTempConnect(SQLiteConnect):
|
||||
"""A temporary SQLite database connection. The database file will be deleted when the connection is closed."""
|
||||
|
||||
def __init__(self, engine, temp_file_path, *args, **kwargs):
|
||||
super().__init__(engine, *args, **kwargs)
|
||||
self.temp_file_path = temp_file_path
|
||||
self._is_closed = False
|
||||
|
||||
@classmethod
|
||||
def create_temporary_db(
|
||||
cls, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> "SQLiteTempConnect":
|
||||
"""Create a temporary SQLite database with a temporary file.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
with SQLiteTempConnect.create_temporary_db() as db:
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run(db.session, "insert into test(id) values (1)")
|
||||
db.run(db.session, "insert into test(id) values (2)")
|
||||
field_names, result = db.query_ex(db.session, "select * from test")
|
||||
assert field_names == ["id"]
|
||||
assert result == [(1,), (2,)]
|
||||
|
||||
Args:
|
||||
engine_args (Optional[dict]): SQLAlchemy engine arguments.
|
||||
|
||||
Returns:
|
||||
SQLiteTempConnect: A SQLiteTempConnect instance.
|
||||
"""
|
||||
_engine_args = engine_args or {}
|
||||
_engine_args["connect_args"] = {"check_same_thread": False}
|
||||
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
|
||||
return cls(engine, temp_file_path, **kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
if not self._is_closed:
|
||||
if self._engine:
|
||||
self._engine.dispose()
|
||||
try:
|
||||
if os.path.exists(self.temp_file_path):
|
||||
os.remove(self.temp_file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing temporary database file: {e}")
|
||||
self._is_closed = True
|
||||
|
||||
def create_temp_tables(self, tables_info):
|
||||
"""Create temporary tables with data.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
tables_info = {
|
||||
"test": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 20),
|
||||
(2, "Jack", 21),
|
||||
(3, "Alice", 22),
|
||||
],
|
||||
},
|
||||
}
|
||||
with SQLiteTempConnect.create_temporary_db() as db:
|
||||
db.create_temp_tables(tables_info)
|
||||
field_names, result = db.query_ex(db.session, "select * from test")
|
||||
assert field_names == ["id", "name", "age"]
|
||||
assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]
|
||||
|
||||
Args:
|
||||
tables_info (dict): A dictionary of table information.
|
||||
"""
|
||||
for table_name, table_data in tables_info.items():
|
||||
columns = ", ".join(
|
||||
[f"{col} {dtype}" for col, dtype in table_data["columns"].items()]
|
||||
)
|
||||
create_sql = f"CREATE TABLE {table_name} ({columns});"
|
||||
self.session.execute(text(create_sql))
|
||||
for row in table_data.get("data", []):
|
||||
placeholders = ", ".join(
|
||||
[":param" + str(index) for index, _ in enumerate(row)]
|
||||
)
|
||||
insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
|
||||
|
||||
param_dict = {
|
||||
"param" + str(index): value for index, value in enumerate(row)
|
||||
}
|
||||
self.session.execute(text(insert_sql), param_dict)
|
||||
self.session.commit()
|
||||
self._sync_tables_from_db()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
@classmethod
|
||||
def is_normal_type(cls) -> bool:
|
||||
return False
|
||||
|
0
dbgpt/rag/operator/__init__.py
Normal file
0
dbgpt/rag/operator/__init__.py
Normal file
14
dbgpt/rag/operator/datasource.py
Normal file
14
dbgpt/rag/operator/datasource.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Any
|
||||
from dbgpt.core.interface.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
def __init__(self, connection: RDBMSDatabase, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
|
||||
def retrieve(self, input_value: Any) -> Any:
|
||||
summary = _parse_db_summary(self._connection)
|
||||
return summary
|
@ -1,5 +1,7 @@
|
||||
from typing import List
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.rag.summary.db_summary import DBSummary
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -36,32 +38,63 @@ class RdbmsSummary(DBSummary):
|
||||
example:
|
||||
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
|
||||
"""
|
||||
columns = []
|
||||
for column in self.db._inspector.get_columns(table_name):
|
||||
if column.get("comment"):
|
||||
columns.append((f"{column['name']} ({column.get('comment')})"))
|
||||
else:
|
||||
columns.append(f"{column['name']}")
|
||||
|
||||
column_str = ", ".join(columns)
|
||||
index_keys = []
|
||||
for index_key in self.db._inspector.get_indexes(table_name):
|
||||
key_str = ", ".join(index_key["column_names"])
|
||||
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
|
||||
table_str = self.summary_template.format(
|
||||
table_name=table_name, columns=column_str
|
||||
)
|
||||
if len(index_keys) > 0:
|
||||
index_key_str = ", ".join(index_keys)
|
||||
table_str += f", and index keys: {index_key_str}"
|
||||
try:
|
||||
comment = self.db._inspector.get_table_comment(table_name)
|
||||
except Exception:
|
||||
comment = dict(text=None)
|
||||
if comment.get("text"):
|
||||
table_str += f", and table comment: {comment.get('text')}"
|
||||
return table_str
|
||||
return _parse_table_summary(self.db, self.summary_template, table_name)
|
||||
|
||||
def table_summaries(self):
|
||||
"""Get table summaries."""
|
||||
return self.table_info_summaries
|
||||
|
||||
|
||||
def _parse_db_summary(
|
||||
conn: RDBMSDatabase, summary_template: str = "{table_name}({columns})"
|
||||
) -> List[str]:
|
||||
"""Get db summary for database.
|
||||
|
||||
Args:
|
||||
conn (RDBMSDatabase): database connection
|
||||
summary_template (str): summary template
|
||||
"""
|
||||
tables = conn.get_table_names()
|
||||
table_info_summaries = [
|
||||
_parse_table_summary(conn, summary_template, table_name)
|
||||
for table_name in tables
|
||||
]
|
||||
return table_info_summaries
|
||||
|
||||
|
||||
def _parse_table_summary(
|
||||
conn: RDBMSDatabase, summary_template: str, table_name: str
|
||||
) -> str:
|
||||
"""Get table summary for table.
|
||||
|
||||
Args:
|
||||
conn (RDBMSDatabase): database connection
|
||||
summary_template (str): summary template
|
||||
table_name (str): table name
|
||||
|
||||
Examples:
|
||||
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
|
||||
"""
|
||||
columns = []
|
||||
for column in conn._inspector.get_columns(table_name):
|
||||
if column.get("comment"):
|
||||
columns.append(f"{column['name']} ({column.get('comment')})")
|
||||
else:
|
||||
columns.append(f"{column['name']}")
|
||||
|
||||
column_str = ", ".join(columns)
|
||||
index_keys = []
|
||||
for index_key in conn._inspector.get_indexes(table_name):
|
||||
key_str = ", ".join(index_key["column_names"])
|
||||
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
|
||||
table_str = summary_template.format(table_name=table_name, columns=column_str)
|
||||
if len(index_keys) > 0:
|
||||
index_key_str = ", ".join(index_keys)
|
||||
table_str += f", and index keys: {index_key_str}"
|
||||
try:
|
||||
comment = conn._inspector.get_table_comment(table_name)
|
||||
except Exception:
|
||||
comment = dict(text=None)
|
||||
if comment.get("text"):
|
||||
table_str += f", and table comment: {comment.get('text')}"
|
||||
return table_str
|
||||
|
@ -3,16 +3,18 @@ from dbgpt.core.awel import DAG
|
||||
from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
|
||||
|
||||
with DAG("simple_sdk_llm_example_dag") as dag:
|
||||
prompt = PromptTemplate.from_template(
|
||||
prompt_task = PromptTemplate.from_template(
|
||||
"Write a SQL of {dialect} to query all data of {table_name}."
|
||||
)
|
||||
req_builder = RequestBuildOperator(model="gpt-3.5-turbo")
|
||||
llm = OpenAILLM()
|
||||
out_parser = BaseOutputParser()
|
||||
prompt >> req_builder >> llm >> out_parser
|
||||
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
|
||||
llm_task = OpenAILLM()
|
||||
out_parse_task = BaseOutputParser()
|
||||
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
|
||||
|
||||
if __name__ == "__main__":
|
||||
output = asyncio.run(
|
||||
out_parser.call(call_data={"data": {"dialect": "mysql", "table_name": "user"}})
|
||||
out_parse_task.call(
|
||||
call_data={"data": {"dialect": "mysql", "table_name": "user"}}
|
||||
)
|
||||
)
|
||||
print(f"output: \n\n{output}")
|
||||
|
150
examples/sdk/simple_sdk_llm_sql_example.py
Normal file
150
examples/sdk/simple_sdk_llm_sql_example.py
Normal file
@ -0,0 +1,150 @@
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
import json
|
||||
from dbgpt.core.awel import (
|
||||
DAG,
|
||||
InputOperator,
|
||||
SimpleCallDataInputSource,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
)
|
||||
from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
|
||||
from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator
|
||||
|
||||
|
||||
def _create_temporary_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
connect = SQLiteTempConnect.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
return connect
|
||||
|
||||
|
||||
def _sql_prompt() -> str:
|
||||
"""This is a prompt template for SQL generation.
|
||||
|
||||
Format of arguments:
|
||||
{db_name}: database name
|
||||
{table_info}: table structure information
|
||||
{dialect}: database dialect
|
||||
{top_k}: maximum number of results
|
||||
{user_input}: user question
|
||||
{response}: response format
|
||||
|
||||
Returns:
|
||||
str: prompt template
|
||||
"""
|
||||
return """Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
|
||||
Database name:
|
||||
{db_name}
|
||||
|
||||
Table structure definition:
|
||||
{table_info}
|
||||
|
||||
Constraint:
|
||||
1.Please understand the user's intention based on the user's question, and use the given table structure definition to create a grammatically correct {dialect} sql. If sql is not required, answer the user's question directly..
|
||||
2.Always limit the query to a maximum of {top_k} results unless the user specifies in the question the specific number of rows of data he wishes to obtain.
|
||||
3.You can only use the tables provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql queries." It is prohibited to fabricate information at will.
|
||||
4.Please be careful not to mistake the relationship between tables and columns when generating SQL.
|
||||
5.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
|
||||
|
||||
User Question:
|
||||
{user_input}
|
||||
Please think step by step and respond according to the following JSON format:
|
||||
{response}
|
||||
Ensure the response is correct json and can be parsed by Python json.loads.
|
||||
"""
|
||||
|
||||
|
||||
def _join_func(query_dict: Dict, db_summary: List[str]):
|
||||
"""Join function for JoinOperator.
|
||||
|
||||
Build the format arguments for the prompt template.
|
||||
|
||||
Args:
|
||||
query_dict (Dict): The query dict from DAG input.
|
||||
db_summary (List[str]): The table structure information from DatasourceRetrieverOperator.
|
||||
|
||||
Returns:
|
||||
Dict: The query dict with the format arguments.
|
||||
"""
|
||||
default_response = {
|
||||
"thoughts": "thoughts summary to say to user",
|
||||
"sql": "SQL Query to run",
|
||||
}
|
||||
response = json.dumps(default_response, ensure_ascii=False, indent=4)
|
||||
query_dict["table_info"] = db_summary
|
||||
query_dict["response"] = response
|
||||
return query_dict
|
||||
|
||||
|
||||
class SQLResultOperator(JoinOperator[Dict]):
|
||||
"""Merge the SQL result and the model result."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(combine_function=self._combine_result, **kwargs)
|
||||
|
||||
def _combine_result(self, sql_result_df, model_result: Dict) -> Dict:
|
||||
model_result["data_df"] = sql_result_df
|
||||
return model_result
|
||||
|
||||
|
||||
with DAG("simple_sdk_llm_sql_example") as dag:
|
||||
db_connection = _create_temporary_connection()
|
||||
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||
retriever_task = DatasourceRetrieverOperator(connection=db_connection)
|
||||
# Merge the input data and the table structure information.
|
||||
prompt_input_task = JoinOperator(combine_function=_join_func)
|
||||
prompt_task = PromptTemplate.from_template(_sql_prompt())
|
||||
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
|
||||
llm_task = OpenAILLM()
|
||||
out_parse_task = SQLOutputParser()
|
||||
sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
|
||||
db_query_task = DatasourceOperator(connection=db_connection)
|
||||
sql_result_task = SQLResultOperator()
|
||||
input_task >> prompt_input_task
|
||||
input_task >> retriever_task >> prompt_input_task
|
||||
(
|
||||
prompt_input_task
|
||||
>> prompt_task
|
||||
>> model_pre_handle_task
|
||||
>> llm_task
|
||||
>> out_parse_task
|
||||
>> sql_parse_task
|
||||
>> db_query_task
|
||||
>> sql_result_task
|
||||
)
|
||||
out_parse_task >> sql_result_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_data = {
|
||||
"data": {
|
||||
"db_name": "test_db",
|
||||
"dialect": "sqlite",
|
||||
"top_k": 5,
|
||||
"user_input": "What is the name and age of the user with age less than 18",
|
||||
}
|
||||
}
|
||||
output = asyncio.run(sql_result_task.call(call_data=input_data))
|
||||
print(f"\nthoughts: {output.get('thoughts')}\n")
|
||||
print(f"sql: {output.get('sql')}\n")
|
||||
print(f"result data:\n{output.get('data_df')}")
|
Loading…
Reference in New Issue
Block a user