Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-08 20:39:44 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:

dbgpt/app/openapi/api_v1/__init__.py (new file, 0 lines)
dbgpt/app/openapi/api_v1/api_v1.py (new file, 520 lines)
@@ -0,0 +1,520 @@
import json
import uuid
import asyncio
import os
import aiofiles
import logging
from fastapi import (
    APIRouter,
    File,
    UploadFile,
    Body,
    Depends,
)

from fastapi.responses import StreamingResponse
from typing import List, Optional
from concurrent.futures import Executor

from dbgpt.component import ComponentType
from dbgpt.app.openapi.api_view_model import (
    Result,
    ConversationVo,
    MessageVo,
    ChatSceneVo,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
    ChatCompletionStreamResponse,
)
from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
from dbgpt._private.config import Config
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest

from dbgpt.app.scene import BaseChat, ChatScene, ChatFactory
from dbgpt.core.interface.message import OnceConversation
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt.model.base import FlatSupportedModel
from dbgpt.util.tracer import root_tracer, SpanType
from dbgpt.util.executor_utils import (
    ExecutorFactory,
    blocking_func_to_async,
    DefaultExecutorFactory,
)

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService()

model_semaphore = None
global_counter = 0


def __get_conv_user_message(conversations: dict):
    messages = conversations["messages"]
    for item in messages:
        if item["type"] == "human":
            return item["data"]["content"]
    return ""


def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
    unique_id = uuid.uuid1()
    return ConversationVo(
        conv_uid=str(unique_id),
        chat_mode=chat_mode,
        user_name=user_name,
        sys_code=sys_code,
    )


def get_db_list():
    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
    db_params = []
    for item in dbs:
        params: dict = {}
        params.update({"param": item["db_name"]})
        params.update({"type": item["db_type"]})
        db_params.append(params)
    return db_params


def plugins_select_info():
    plugins_infos: dict = {}
    for plugin in CFG.plugins:
        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
    return plugins_infos


def get_db_list_info():
    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
    params: dict = {}
    for item in dbs:
        comment = item["comment"]
        if comment is not None and len(comment) > 0:
            params.update({item["db_name"]: comment})
    return params


def knowledge_list_info():
    """Return the knowledge space list as a name -> description mapping."""
    params: dict = {}
    request = KnowledgeSpaceRequest()
    spaces = knowledge_service.get_knowledge_space(request)
    for space in spaces:
        params.update({space.name: space.desc})
    return params


def knowledge_list():
    """Return the knowledge space list."""
    request = KnowledgeSpaceRequest()
    spaces = knowledge_service.get_knowledge_space(request)
    space_list = []
    for space in spaces:
        params: dict = {}
        params.update({"param": space.name})
        params.update({"type": "space"})
        space_list.append(params)
    return space_list


def get_model_controller() -> BaseModelController:
    controller = CFG.SYSTEM_APP.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
    return controller


def get_worker_manager() -> WorkerManager:
    worker_manager = CFG.SYSTEM_APP.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return worker_manager


def get_executor() -> Executor:
    """Get the global default executor"""
    return CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT,
        ExecutorFactory,
        or_register_component=DefaultExecutorFactory,
    ).create()
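
For reference, a minimal sketch (not part of this commit) of how get_executor pairs with blocking_func_to_async so blocking calls stay off the event loop; count_table_rows and _count are hypothetical stand-ins for any synchronous helper:

import time


async def count_table_rows(db_name: str) -> int:
    def _count(name: str) -> int:
        # Hypothetical blocking work, e.g. a synchronous DB query.
        time.sleep(1)
        return 42

    # Schedule the blocking function on the shared default executor and await it.
    return await blocking_func_to_async(get_executor(), _count, db_name)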


@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
    return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())


@router.post("/v1/chat/db/add", response_model=Result[bool])
async def db_connect_add(db_config: DBConfig = Body()):
    return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))


@router.post("/v1/chat/db/edit", response_model=Result[bool])
async def db_connect_edit(db_config: DBConfig = Body()):
    return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))


@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def db_connect_delete(db_name: str = None):
    return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))


async def async_db_summary_embedding(db_name, db_type):
    db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
    db_summary_client.db_summary_embedding(db_name, db_type)


@router.post("/v1/chat/db/test/connect", response_model=Result[bool])
async def test_connect(db_config: DBConfig = Body()):
    try:
        # TODO Change the synchronous call to the asynchronous call
        CFG.LOCAL_DB_MANAGE.test_connect(db_config)
        return Result.succ(True)
    except Exception as e:
        return Result.failed(code="E1001", msg=str(e))


@router.post("/v1/chat/db/summary", response_model=Result[bool])
async def db_summary(db_name: str, db_type: str):
    # TODO Change the synchronous call to the asynchronous call
    # The coroutine must be awaited, otherwise the embedding never runs.
    await async_db_summary_embedding(db_name, db_type)
    return Result.succ(True)


@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
    support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
    db_type_infos = []
    for type in support_types:
        db_type_infos.append(
            DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
        )
    return Result[DbTypeInfo].succ(db_type_infos)
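
For reference, a minimal client sketch (not part of this commit) against the connection endpoints; the base URL and any DBConfig fields beyond db_name/db_type are assumptions, not taken from this diff:

import requests

BASE = "http://localhost:5000/api"  # assumed mount prefix and port

# Register a connection (extra fields such as db_path are assumed).
print(
    requests.post(
        f"{BASE}/v1/chat/db/add",
        json={"db_type": "sqlite", "db_name": "test_db", "db_path": "/tmp/test.db"},
    ).json()
)
# List configured connections.
print(requests.get(f"{BASE}/v1/chat/db/list").json())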


@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(
    user_name: str = None, user_id: str = None, sys_code: str = None
):
    dialogues: List = []
    chat_history_service = ChatHistory()
    # TODO Change the synchronous call to the asynchronous call
    user_name = user_name or user_id
    datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
    for item in datas:
        conv_uid = item.get("conv_uid")
        summary = item.get("summary")
        chat_mode = item.get("chat_mode")
        model_name = item.get("model_name", CFG.LLM_MODEL)
        user_name = item.get("user_name")
        sys_code = item.get("sys_code")

        messages = json.loads(item.get("messages"))
        last_round = max(messages, key=lambda x: x["chat_order"])
        if "param_value" in last_round:
            select_param = last_round["param_value"]
        else:
            select_param = ""
        conv_vo: ConversationVo = ConversationVo(
            conv_uid=conv_uid,
            user_input=summary,
            chat_mode=chat_mode,
            model_name=model_name,
            select_param=select_param,
            user_name=user_name,
            sys_code=sys_code,
        )
        dialogues.append(conv_vo)

    return Result[ConversationVo].succ(dialogues[:10])


@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    scene_vos: List[ChatSceneVo] = []
    new_modes: List[ChatScene] = [
        ChatScene.ChatWithDbExecute,
        ChatScene.ChatWithDbQA,
        ChatScene.ChatExcel,
        ChatScene.ChatKnowledge,
        ChatScene.ChatDashboard,
        ChatScene.ChatAgent,
    ]
    for scene in new_modes:
        scene_vo = ChatSceneVo(
            chat_scene=scene.value(),
            scene_name=scene.scene_name(),
            scene_describe=scene.describe(),
            param_title=",".join(scene.param_types()),
            show_disable=scene.show_disable(),
        )
        scene_vos.append(scene_vo)
    return Result.succ(scene_vos)


@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
    chat_mode: str = ChatScene.ChatNormal.value(),
    user_name: str = None,
    # TODO remove user id
    user_id: str = None,
    sys_code: str = None,
):
    user_name = user_name or user_id
    conv_vo = __new_conversation(chat_mode, user_name, sys_code)
    return Result.succ(conv_vo)


@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
    if ChatScene.ChatWithDbQA.value() == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatWithDbExecute.value() == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatDashboard.value() == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatExecution.value() == chat_mode:
        return Result.succ(plugins_select_info())
    elif ChatScene.ChatKnowledge.value() == chat_mode:
        return Result.succ(knowledge_list())
    elif ChatScene.ChatKnowledge.ExtractRefineSummary.value() == chat_mode:
        return Result.succ(knowledge_list())
    else:
        return Result.succ(None)
@router.post("/v1/chat/mode/params/file/load")
|
||||
async def params_load(
|
||||
conv_uid: str,
|
||||
chat_mode: str,
|
||||
model_name: str,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
doc_file: UploadFile = File(...),
|
||||
):
|
||||
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
# Save the uploaded file
|
||||
upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
upload_path = os.path.join(upload_dir, doc_file.filename)
|
||||
async with aiofiles.open(upload_path, "wb") as f:
|
||||
await f.write(await doc_file.read())
|
||||
|
||||
# Prepare the chat
|
||||
dialogue = ConversationVo(
|
||||
conv_uid=conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
select_param=doc_file.filename,
|
||||
model_name=model_name,
|
||||
user_name=user_name,
|
||||
sys_code=sys_code,
|
||||
)
|
||||
chat: BaseChat = await get_chat_instance(dialogue)
|
||||
resp = await chat.prepare()
|
||||
|
||||
# Refresh messages
|
||||
return Result.succ(get_hist_messages(conv_uid))
|
||||
except Exception as e:
|
||||
logger.error("excel load error!", e)
|
||||
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
|
||||
|
||||
|
||||
@router.post("/v1/chat/dialogue/delete")
|
||||
async def dialogue_delete(con_uid: str):
|
||||
history_fac = ChatHistory()
|
||||
history_mem = history_fac.get_store_instance(con_uid)
|
||||
# TODO Change the synchronous call to the asynchronous call
|
||||
history_mem.delete()
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
def get_hist_messages(conv_uid: str):
|
||||
message_vos: List[MessageVo] = []
|
||||
history_fac = ChatHistory()
|
||||
history_mem = history_fac.get_store_instance(conv_uid)
|
||||
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
for once in history_messages:
|
||||
model_name = once.get("model_name", CFG.LLM_MODEL)
|
||||
once_message_vos = [
|
||||
message2Vo(element, once["chat_order"], model_name)
|
||||
for element in once["messages"]
|
||||
]
|
||||
message_vos.extend(once_message_vos)
|
||||
return message_vos
|
||||
|
||||
|
||||
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
|
||||
async def dialogue_history_messages(con_uid: str):
|
||||
print(f"dialogue_history_messages:{con_uid}")
|
||||
# TODO Change the synchronous call to the asynchronous call
|
||||
return Result.succ(get_hist_messages(con_uid))
|
||||
|
||||
|
||||
async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
||||
logger.info(f"get_chat_instance:{dialogue}")
|
||||
if not dialogue.chat_mode:
|
||||
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
||||
if not dialogue.conv_uid:
|
||||
conv_vo = __new_conversation(
|
||||
dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
|
||||
)
|
||||
dialogue.conv_uid = conv_vo.conv_uid
|
||||
|
||||
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
||||
raise StopAsyncIteration(
|
||||
Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!")
|
||||
)
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": dialogue.conv_uid,
|
||||
"user_name": dialogue.user_name,
|
||||
"sys_code": dialogue.sys_code,
|
||||
"current_user_input": dialogue.user_input,
|
||||
"select_param": dialogue.select_param,
|
||||
"model_name": dialogue.model_name,
|
||||
}
|
||||
chat: BaseChat = await blocking_func_to_async(
|
||||
get_executor(),
|
||||
CHAT_FACTORY.get_implementation,
|
||||
dialogue.chat_mode,
|
||||
**{"chat_param": chat_param},
|
||||
)
|
||||
return chat
|
||||
|
||||
|
||||
@router.post("/v1/chat/prepare")
|
||||
async def chat_prepare(dialogue: ConversationVo = Body()):
|
||||
# dialogue.model_name = CFG.LLM_MODEL
|
||||
logger.info(f"chat_prepare:{dialogue}")
|
||||
## check conv_uid
|
||||
chat: BaseChat = await get_chat_instance(dialogue)
|
||||
if len(chat.history_message) > 0:
|
||||
return Result.succ(None)
|
||||
resp = await chat.prepare()
|
||||
return Result.succ(resp)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
print(
|
||||
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
||||
)
|
||||
with root_tracer.start_span(
|
||||
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
||||
):
|
||||
chat: BaseChat = await get_chat_instance(dialogue)
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
|
||||
if not chat.prompt_template.stream_out:
|
||||
return StreamingResponse(
|
||||
no_stream_generator(chat),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
stream_generator(chat, dialogue.incremental, dialogue.model_name),
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
)
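
For reference, a minimal consumer sketch (not part of this commit) for the /v1/chat/completions stream; the base URL and model name are assumptions, and "chat_normal" assumes the value of ChatScene.ChatNormal.value():

import requests

URL = "http://localhost:5000/api/v1/chat/completions"  # assumed mount point
body = {
    "conv_uid": "",  # empty -> get_chat_instance creates a new conversation
    "chat_mode": "chat_normal",
    "user_input": "Hello",
    "model_name": "vicuna-13b-v1.5",  # assumed model name
    "incremental": True,
}

with requests.post(URL, json=body, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        print(payload)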


@router.get("/v1/model/types")
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
    logger.info("/controller/model/types")
    try:
        types = set()
        models = await controller.get_all_instances(healthy_only=True)
        for model in models:
            worker_name, worker_type = model.model_name.split("@")
            if worker_type == "llm":
                types.add(worker_name)
        return Result.succ(list(types))

    except Exception as e:
        return Result.failed(code="E000X", msg=f"controller model types error {e}")


@router.get("/v1/model/supports")
async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
    logger.info("/controller/model/supports")
    try:
        models = await worker_manager.supported_models()
        return Result.succ(FlatSupportedModel.from_supports(models))
    except Exception as e:
        return Result.failed(code="E000X", msg=f"Fetch supported models error {e}")


async def no_stream_generator(chat):
    with root_tracer.start_span("no_stream_generator"):
        msg = await chat.nostream_call()
        yield f"data: {msg}\n\n"


async def stream_generator(chat, incremental: bool, model_name: str):
    """Generate streaming responses.

    Our goal is to generate an OpenAI-compatible streaming response.
    Currently only the incremental response is compatible; the full response will be
    transformed in the future.

    Args:
        chat (BaseChat): Chat instance.
        incremental (bool): Whether the content is returned incrementally or in full each time.
        model_name (str): The model name.

    Yields:
        str: Streaming responses.
    """
    span = root_tracer.start_span("stream_generator")
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
    previous_response = ""
    async for chunk in chat.stream_call():
        if chunk:
            msg = chunk.replace("\ufffd", "")
            if incremental:
                incremental_output = msg[len(previous_response) :]
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=incremental_output),
                )
                chunk = ChatCompletionStreamResponse(
                    id=stream_id, choices=[choice_data], model=model_name
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            else:
                # TODO generate an OpenAI-compatible streaming response
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
            previous_response = msg
            await asyncio.sleep(0.02)
    if incremental:
        yield "data: [DONE]\n\n"
    span.end()


def message2Vo(message: dict, order, model_name) -> MessageVo:
    return MessageVo(
        role=message["type"],
        context=message["data"]["content"],
        order=order,
        model_name=model_name,
    )
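
For reference, the two framings stream_generator emits (illustrative payloads, not captured output): with incremental=True, each frame is an OpenAI-style chunk and the stream ends with a terminator, e.g.

data: {"id": "chatcmpl-<uuid>", "model": "vicuna-13b-v1.5", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}
data: [DONE]

With incremental=False, each frame carries the full accumulated text so far, with newlines escaped, e.g.

data:Hello\nWorld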

dbgpt/app/openapi/api_v1/editor/__init__.py (new file, 0 lines)
dbgpt/app/openapi/api_v1/editor/api_editor_v1.py (new file, 346 lines)
@@ -0,0 +1,346 @@
import json
import time
from fastapi import (
    APIRouter,
    Body,
)

from typing import List
import logging

from dbgpt._private.config import Config

from dbgpt.app.scene import ChatFactory

from dbgpt.app.openapi.api_view_model import (
    Result,
)
from dbgpt.app.openapi.editor_view_model import (
    ChatDbRounds,
    ChartList,
    ChartDetail,
    ChatChartEditContext,
    ChatSqlEditContext,
    DbTable,
)

from dbgpt.app.openapi.api_v1.editor.sql_editor import (
    DataNode,
    ChartRunData,
    SqlRunData,
)
from dbgpt.core.interface.message import OnceConversation
from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
from dbgpt.app.scene.chat_db.data_loader import DbDataLoader
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()

logger = logging.getLogger(__name__)
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||
async def get_editor_tables(
|
||||
db_name: str, page_index: int, page_size: int, search_str: str = ""
|
||||
):
|
||||
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
tables = db_conn.get_table_names()
|
||||
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
|
||||
for table in tables:
|
||||
table_node: DataNode = DataNode(title=table, key=table, type="table")
|
||||
db_node.children.append(table_node)
|
||||
fields = db_conn.get_fields(table)
|
||||
for field in fields:
|
||||
table_node.children.append(
|
||||
DataNode(
|
||||
title=field[0],
|
||||
key=field[0],
|
||||
type=field[1],
|
||||
default_value=field[2],
|
||||
can_null=field[3] or "YES",
|
||||
comment=str(field[-1]),
|
||||
)
|
||||
)
|
||||
|
||||
return Result.succ(db_node)
|
||||
|
||||
|
||||
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
||||
async def get_editor_sql_rounds(con_uid: str):
|
||||
logger.info("get_editor_sql_rounds:{con_uid}")
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
result: List = []
|
||||
for once in history_messages:
|
||||
round_name: str = ""
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "human":
|
||||
round_name = element["data"]["content"]
|
||||
if once.get("param_value"):
|
||||
round: ChatDbRounds = ChatDbRounds(
|
||||
round=once["chat_order"],
|
||||
db_name=once["param_value"],
|
||||
round_name=round_name,
|
||||
)
|
||||
result.append(round)
|
||||
return Result.succ(result)
|
||||
|
||||
|
||||
@router.get("/v1/editor/sql", response_model=Result[dict])
|
||||
async def get_editor_sql(con_uid: str, round: int):
|
||||
logger.info(f"get_editor_sql:{con_uid},{round}")
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
for once in history_messages:
|
||||
if int(once["chat_order"]) == round:
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "ai":
|
||||
logger.info(
|
||||
f'history ai json resp:{element["data"]["content"]}'
|
||||
)
|
||||
context = (
|
||||
element["data"]["content"]
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
)
|
||||
return Result.succ(json.loads(context))
|
||||
return Result.failed(msg="not have sql!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||
async def editor_sql_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_sql_run:{run_param}")
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
if not db_name and not sql:
|
||||
return Result.failed("SQL run param error!")
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
|
||||
try:
|
||||
start_time = time.time() * 1000
|
||||
colunms, sql_result = conn.query_ex(sql)
|
||||
# 转换结果类型
|
||||
sql_result = [tuple(x) for x in sql_result]
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result,
|
||||
)
|
||||
return Result.succ(sql_run_data)
|
||||
except Exception as e:
|
||||
logging.error("editor_sql_run exception!" + str(e))
|
||||
return Result.succ(
|
||||
SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
||||
)
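
For reference, a minimal request sketch (not part of this commit) for /v1/editor/sql/run; the base URL is an assumption:

import requests

resp = requests.post(
    "http://localhost:5000/api/v1/editor/sql/run",  # assumed mount point
    json={"db_name": "test_db", "sql": "SELECT 1"},
)
# The response data mirrors SqlRunData: result_info, run_cost, colunms, values.
print(resp.json())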


@router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
    logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")

    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(sql_edit_context.conv_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)

        edit_round = list(
            filter(
                lambda x: x["chat_order"] == sql_edit_context.conv_round,
                history_messages,
            )
        )[0]
        if edit_round:
            for element in edit_round["messages"]:
                if element["type"] == "ai":
                    db_resp = json.loads(element["data"]["content"])
                    db_resp["thoughts"] = sql_edit_context.new_speak
                    db_resp["sql"] = sql_edit_context.new_sql
                    element["data"]["content"] = json.dumps(db_resp)
                if element["type"] == "view":
                    data_loader = DbDataLoader()
                    element["data"]["content"] = data_loader.get_table_view_by_conn(
                        conn.run_to_df(sql_edit_context.new_sql),
                        sql_edit_context.new_speak,
                    )
            history_mem.update(history_messages)
            return Result.succ(None)
    return Result.failed(msg="Edit Failed!")
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
||||
async def get_editor_chart_list(con_uid: str):
|
||||
logger.info(
|
||||
f"get_editor_sql_rounds:{con_uid}",
|
||||
)
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
chart_list: ChartList = ChartList(
|
||||
round=last_round["chat_order"],
|
||||
db_name=db_name,
|
||||
charts=json.loads(element["data"]["content"]),
|
||||
)
|
||||
return Result.succ(chart_list)
|
||||
return Result.failed(msg="Not have charts!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_info(param: dict = Body()):
|
||||
logger.info(f"get_editor_chart_info:{param}")
|
||||
conv_uid = param["con_uid"]
|
||||
chart_title = param["chart_title"]
|
||||
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
if not db_name:
|
||||
logger.error(
|
||||
"this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
return Result.failed(
|
||||
msg="this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"])
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(
|
||||
filter(lambda x: x["chart_name"] == chart_title, charts)
|
||||
)[0]
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(
|
||||
chart_uid=find_chart["chart_uid"],
|
||||
chart_type=find_chart["chart_type"],
|
||||
chart_desc=find_chart["chart_desc"],
|
||||
chart_sql=find_chart["chart_sql"],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart["chart_name"],
|
||||
chart_value=find_chart["values"],
|
||||
table_value=conn.run(find_chart["chart_sql"]),
|
||||
)
|
||||
|
||||
return Result.succ(detail)
|
||||
return Result.failed(msg="Can't Find Chart Detail Info!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
async def editor_chart_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_chart_run:{run_param}")
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
chart_type = run_param["chart_type"]
|
||||
if not db_name and not sql:
|
||||
return Result.failed("SQL run param error!")
|
||||
try:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
colunms, sql_result = db_conn.query_ex(sql)
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
|
||||
colunms, sql_result, sql
|
||||
)
|
||||
|
||||
start_time = time.time() * 1000
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result,
|
||||
)
|
||||
return Result.succ(
|
||||
ChartRunData(
|
||||
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
||||
return Result.succ(
|
||||
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
||||
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
|
||||
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
|
||||
|
||||
edit_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
if edit_round:
|
||||
try:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"])
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(
|
||||
filter(
|
||||
lambda x: x["chart_name"]
|
||||
== chart_edit_context.chart_title,
|
||||
charts,
|
||||
)
|
||||
)[0]
|
||||
if chart_edit_context.new_chart_type:
|
||||
find_chart["chart_type"] = chart_edit_context.new_chart_type
|
||||
if chart_edit_context.new_comment:
|
||||
find_chart["chart_desc"] = chart_edit_context.new_comment
|
||||
|
||||
(
|
||||
field_names,
|
||||
chart_values,
|
||||
) = dashboard_data_loader.get_chart_values_by_conn(
|
||||
db_conn, chart_edit_context.new_sql
|
||||
)
|
||||
find_chart["chart_sql"] = chart_edit_context.new_sql
|
||||
find_chart["values"] = [value.dict() for value in chart_values]
|
||||
find_chart["column_name"] = field_names
|
||||
|
||||
element["data"]["content"] = json.dumps(
|
||||
view_data, ensure_ascii=False
|
||||
)
|
||||
if element["type"] == "ai":
|
||||
ai_resp: dict = json.loads(element["data"]["content"])
|
||||
edit_item = list(
|
||||
filter(
|
||||
lambda x: x["title"] == chart_edit_context.chart_title,
|
||||
ai_resp,
|
||||
)
|
||||
)[0]
|
||||
|
||||
edit_item["sql"] = chart_edit_context.new_sql
|
||||
edit_item["showcase"] = chart_edit_context.new_chart_type
|
||||
edit_item["thoughts"] = chart_edit_context.new_comment
|
||||
element["data"]["content"] = json.dumps(
|
||||
ai_resp, ensure_ascii=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"edit chart exception!{str(e)}", e)
|
||||
return Result.failed(msg=f"Edit chart exception!{str(e)}")
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.failed(msg="Edit Failed!")

dbgpt/app/openapi/api_v1/editor/sql_editor.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem


class DataNode(BaseModel):
    title: str
    key: str

    type: str = ""
    default_value: Optional[str] = None
    can_null: str = "YES"
    comment: Optional[str] = None
    children: List = []


class SqlRunData(BaseModel):
    result_info: str
    run_cost: str
    colunms: List[str]
    values: List


class ChartRunData(BaseModel):
    sql_data: SqlRunData
    chart_values: List[ValueItem]
    chart_type: str
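
For reference, a minimal sketch (not part of this commit) of assembling a DataNode tree the way get_editor_tables in api_editor_v1.py does; the titles and types are made-up values:

# Build a db -> table -> column tree.
db_node = DataNode(title="test_db", key="test_db", type="db")
table_node = DataNode(title="users", key="users", type="table")
db_node.children.append(table_node)
table_node.children.append(
    DataNode(title="id", key="id", type="BIGINT", can_null="NO", comment="primary key")
)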

dbgpt/app/openapi/api_v1/feedback/__init__.py (new file, 0 lines)
dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from fastapi import APIRouter, Body, Request

from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import (
    ChatFeedBackDao,
)
from dbgpt.app.openapi.api_view_model import Result

router = APIRouter()
chat_feed_back = ChatFeedBackDao()


@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
    rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
    if rt is not None:
        return Result.succ(
            FeedBackBody(
                conv_uid=rt.conv_uid,
                conv_index=rt.conv_index,
                question=rt.question,
                knowledge_space=rt.knowledge_space,
                score=rt.score,
                ques_type=rt.ques_type,
                messages=rt.messages,
            )
        )
    else:
        return Result.succ(None)


@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
    chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
    return Result.succ(True)


@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
    # Question-type options; the values are Chinese UI labels.
    return Result.succ(
        {
            "information": "信息查询",  # information lookup
            "work_study": "工作学习",  # work and study
            "just_fun": "互动闲聊",  # casual chat
            "others": "其他",  # others
        }
    )
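
For reference, a minimal client sketch (not part of this commit) for the feedback endpoints; the base URL is an assumption, and ques_type uses one of the keys returned by /v1/feedback/select:

import requests

BASE = "http://localhost:5000/api"  # assumed mount prefix

# Commit a rating for round 1 of a conversation.
requests.post(
    f"{BASE}/v1/feedback/commit",
    json={
        "conv_uid": "some-conv-uid",
        "conv_index": 1,
        "question": "Which tables exist?",
        "score": 5,
        "ques_type": "information",
    },
)
# Read the stored feedback back.
print(
    requests.get(
        f"{BASE}/v1/feedback/find",
        params={"conv_uid": "some-conv-uid", "conv_index": 1},
    ).json()
)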

dbgpt/app/openapi/api_v1/feedback/feed_back_db.py (new file, 95 lines)
@@ -0,0 +1,95 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime

from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
    Base,
    engine,
    session,
    META_DATA_DATABASE,
)
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody


class ChatFeedBackEntity(Base):
    __tablename__ = "chat_feed_back"
    __table_args__ = {
        "mysql_charset": "utf8mb4",
        "mysql_collate": "utf8mb4_unicode_ci",
    }
    id = Column(Integer, primary_key=True)
    conv_uid = Column(String(128))
    conv_index = Column(Integer)
    score = Column(Integer)
    ques_type = Column(String(32))
    question = Column(Text)
    knowledge_space = Column(String(128))
    messages = Column(Text)
    user_name = Column(String(128))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return (
            f"ChatFeedBackEntity(id={self.id}, conv_uid='{self.conv_uid}', conv_index='{self.conv_index}', "
            f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
            f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        )


class ChatFeedBackDao(BaseDao):
    def __init__(self):
        super().__init__(
            database=META_DATA_DATABASE,
            orm_base=Base,
            db_engine=engine,
            session=session,
        )

    def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
        # Todo: We need to have user information first.

        session = self.get_session()
        chat_feed_back = ChatFeedBackEntity(
            conv_uid=feed_back.conv_uid,
            conv_index=feed_back.conv_index,
            score=feed_back.score,
            ques_type=feed_back.ques_type,
            question=feed_back.question,
            knowledge_space=feed_back.knowledge_space,
            messages=feed_back.messages,
            user_name=feed_back.user_name,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now(),
        )
        # Upsert: update the existing record for this (conv_uid, conv_index)
        # if present, otherwise insert a new one.
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
            .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
            .first()
        )
        if result is not None:
            result.score = feed_back.score
            result.ques_type = feed_back.ques_type
            result.question = feed_back.question
            result.knowledge_space = feed_back.knowledge_space
            result.messages = feed_back.messages
            result.user_name = feed_back.user_name
            result.gmt_created = datetime.now()
            result.gmt_modified = datetime.now()
        else:
            session.merge(chat_feed_back)
        session.commit()
        session.close()

    def get_chat_feed_back(self, conv_uid: str, conv_index: int):
        session = self.get_session()
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == conv_uid)
            .filter(ChatFeedBackEntity.conv_index == conv_index)
            .first()
        )
        session.close()
        return result

dbgpt/app/openapi/api_v1/feedback/feed_back_model.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from dbgpt._private.pydantic import BaseModel
from typing import Optional


class FeedBackBody(BaseModel):
    """conv_uid: conversation id"""

    conv_uid: str

    """conv_index: conversation index"""
    conv_index: int

    """question: human question"""
    question: str

    """score: rating of the llm's answer"""
    score: int

    """ques_type: question type"""
    ques_type: str

    user_name: Optional[str] = None

    """messages: rating detail"""
    messages: Optional[str] = None

    """knowledge_space: knowledge space"""
    knowledge_space: Optional[str] = None