DB-GPT/dbgpt/app/openapi/api_v1/api_v1.py
明天 d5afa6e206 Native data AI application framework based on AWEL+AGENT (#1152)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com>
Co-authored-by: licunxing <864255598@qq.com>
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: xuyuan23 <643854343@qq.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: hzh97 <2976151305@qq.com>
2024-02-07 17:43:27 +08:00


import asyncio
import logging
import os
import uuid
from concurrent.futures import Executor
from typing import List, Optional

import aiofiles
from fastapi import APIRouter, Body, Depends, File, UploadFile
from fastapi.responses import StreamingResponse

from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.openapi.api_view_model import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatSceneVo,
    ConversationVo,
    DeltaMessage,
    MessageVo,
    Result,
)
from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
from dbgpt.model.base import FlatSupportedModel
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
from dbgpt.serve.agent.agents.controller import multi_agents
from dbgpt.serve.flow.service.service import Service as FlowService
from dbgpt.util.executor_utils import (
    DefaultExecutorFactory,
    ExecutorFactory,
    blocking_func_to_async,
)
from dbgpt.util.tracer import SpanType, root_tracer
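
# Module-level singletons shared by all handlers below: the FastAPI router, global
# config, chat factory, and the knowledge service used by the parameter lists.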
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0


def __get_conv_user_message(conversations: dict):
    messages = conversations["messages"]
    for item in messages:
        if item["type"] == "human":
            return item["data"]["content"]
    return ""


def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
    unique_id = uuid.uuid1()
    return ConversationVo(
        conv_uid=str(unique_id),
        chat_mode=chat_mode,
        user_name=user_name,
        sys_code=sys_code,
    )


def get_db_list():
    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
    db_params = []
    for item in dbs:
        params: dict = {}
        params.update({"param": item["db_name"]})
        params.update({"type": item["db_type"]})
        db_params.append(params)
    return db_params


def plugins_select_info():
    plugins_infos: dict = {}
    for plugin in CFG.plugins:
        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
    return plugins_infos


def get_db_list_info():
    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
    params: dict = {}
    for item in dbs:
        comment = item["comment"]
        if comment is not None and len(comment) > 0:
            params.update({item["db_name"]: comment})
    return params


def knowledge_list_info():
    """Return the knowledge space list."""
    params: dict = {}
    request = KnowledgeSpaceRequest()
    spaces = knowledge_service.get_knowledge_space(request)
    for space in spaces:
        params.update({space.name: space.desc})
    return params


def knowledge_list():
    """Return the knowledge space list."""
    request = KnowledgeSpaceRequest()
    spaces = knowledge_service.get_knowledge_space(request)
    space_list = []
    for space in spaces:
        params: dict = {}
        params.update({"param": space.name})
        params.update({"type": "space"})
        space_list.append(params)
    return space_list


def get_model_controller() -> BaseModelController:
    controller = CFG.SYSTEM_APP.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
    return controller


def get_worker_manager() -> WorkerManager:
    worker_manager = CFG.SYSTEM_APP.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return worker_manager


def get_chat_flow() -> FlowService:
    """Get Chat Flow Service."""
    return FlowService.get_instance(CFG.SYSTEM_APP)


def get_executor() -> Executor:
    """Get the global default executor."""
    return CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT,
        ExecutorFactory,
        or_register_component=DefaultExecutorFactory,
    ).create()
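

# Datasource management endpoints backed by CFG.LOCAL_DB_MANAGE: list, add, edit,
# delete and test connections, plus schema summary embedding and supported DB types.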
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@router.post("/v1/chat/db/add", response_model=Result[bool])
async def db_connect_add(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
@router.post("/v1/chat/db/edit", response_model=Result[bool])
async def db_connect_edit(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def db_connect_delete(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))


async def async_db_summary_embedding(db_name, db_type):
    db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
    db_summary_client.db_summary_embedding(db_name, db_type)


@router.post("/v1/chat/db/test/connect", response_model=Result[bool])
async def test_connect(db_config: DBConfig = Body()):
    try:
        # TODO Change the synchronous call to the asynchronous call
        CFG.LOCAL_DB_MANAGE.test_connect(db_config)
        return Result.succ(True)
    except Exception as e:
        return Result.failed(code="E1001", msg=str(e))
@router.post("/v1/chat/db/summary", response_model=Result[bool])
async def db_summary(db_name: str, db_type: str):
# TODO Change the synchronous call to the asynchronous call
async_db_summary_embedding(db_name, db_type)
return Result.succ(True)
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
db_type_infos = []
for type in support_types:
db_type_infos.append(
DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
)
return Result[DbTypeInfo].succ(db_type_infos)
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes: List[ChatScene] = [
ChatScene.ChatWithDbExecute,
ChatScene.ChatWithDbQA,
ChatScene.ChatExcel,
ChatScene.ChatKnowledge,
ChatScene.ChatDashboard,
ChatScene.ChatAgent,
]
for scene in new_modes:
scene_vo = ChatSceneVo(
chat_scene=scene.value(),
scene_name=scene.scene_name(),
scene_describe=scene.describe(),
param_title=",".join(scene.param_types()),
show_disable=scene.show_disable(),
)
scene_vos.append(scene_vo)
return Result.succ(scene_vos)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
if ChatScene.ChatWithDbQA.value() == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatWithDbExecute.value() == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value() == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatExecution.value() == chat_mode:
return Result.succ(plugins_select_info())
elif ChatScene.ChatKnowledge.value() == chat_mode:
return Result.succ(knowledge_list())
elif ChatScene.ChatKnowledge.ExtractRefineSummary.value() == chat_mode:
return Result.succ(knowledge_list())
else:
return Result.succ(None)
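

# File upload for document-based chat modes (e.g. ChatExcel): save the uploaded file
# under the knowledge upload root, build the chat instance and run its prepare step,
# then return the refreshed history for the conversation.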
@router.post("/v1/chat/mode/params/file/load")
async def params_load(
conv_uid: str,
chat_mode: str,
model_name: str,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
# Save the uploaded file
upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
os.makedirs(upload_dir, exist_ok=True)
upload_path = os.path.join(upload_dir, doc_file.filename)
async with aiofiles.open(upload_path, "wb") as f:
await f.write(await doc_file.read())
# Prepare the chat
dialogue = ConversationVo(
conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
user_name=user_name,
sys_code=sys_code,
)
chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare()
# Refresh messages
return Result.succ(get_hist_messages(conv_uid))
except Exception as e:
logger.error("excel load error!", e)
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")


def get_hist_messages(conv_uid: str):
    from dbgpt.serve.conversation.serve import Service as ConversationService

    instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
    return instance.get_history_messages({"conv_uid": conv_uid})


async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid
    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!")
    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_name": dialogue.user_name,
        "sys_code": dialogue.sys_code,
        "current_user_input": dialogue.user_input,
        "select_param": dialogue.select_param,
        "model_name": dialogue.model_name,
    }
    # Building the chat instance may block, so run the factory call in the
    # default executor instead of the event loop.
    chat: BaseChat = await blocking_func_to_async(
        get_executor(),
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        **{"chat_param": chat_param},
    )
    return chat
@router.post("/v1/chat/prepare")
async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = await get_chat_instance(dialogue)
if chat.has_history_messages():
return Result.succ(None)
resp = await chat.prepare()
return Result.succ(resp)
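

# Main chat entry point. Requests are dispatched by chat_mode: agent chats go to the
# multi-agent controller, flow chats go to the AWEL flow service, and all other
# scenes use a BaseChat instance streamed back as server-sent events.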
@router.post("/v1/chat/completions")
async def chat_completions(
dialogue: ConversationVo = Body(),
flow_service: FlowService = Depends(get_chat_flow),
):
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
if dialogue.chat_mode == ChatScene.ChatAgent.value():
return StreamingResponse(
multi_agents.app_agent_chat(
conv_uid=dialogue.conv_uid,
gpts_name=dialogue.select_param,
user_query=dialogue.user_input,
user_code=dialogue.user_name,
sys_code=dialogue.sys_code,
),
headers=headers,
media_type="text/event-stream",
)
elif dialogue.chat_mode == ChatScene.ChatFlow.value():
flow_ctx = CommonLLMHTTPRequestContext(
conv_uid=dialogue.conv_uid,
chat_mode=dialogue.chat_mode,
user_name=dialogue.user_name,
sys_code=dialogue.sys_code,
)
flow_req = CommonLLMHttpRequestBody(
model=dialogue.model_name,
messages=dialogue.user_input,
stream=True,
context=flow_ctx,
)
return StreamingResponse(
flow_stream_generator(
flow_service.chat_flow(dialogue.select_param, flow_req),
dialogue.incremental,
dialogue.model_name,
),
headers=headers,
media_type="text/event-stream",
)
else:
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = await get_chat_instance(dialogue)
if not chat.prompt_template.stream_out:
return StreamingResponse(
no_stream_generator(chat),
headers=headers,
media_type="text/event-stream",
)
else:
return StreamingResponse(
stream_generator(chat, dialogue.incremental, dialogue.model_name),
headers=headers,
media_type="text/plain",
)
@router.get("/v1/model/types")
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
logger.info(f"/controller/model/types")
try:
types = set()
models = await controller.get_all_instances(healthy_only=True)
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type == "llm":
types.add(worker_name)
return Result.succ(list(types))
except Exception as e:
return Result.failed(code="E000X", msg=f"controller model types error {e}")
@router.get("/v1/model/supports")
async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
logger.info(f"/controller/model/supports")
try:
models = await worker_manager.supported_models()
return Result.succ(FlatSupportedModel.from_supports(models))
except Exception as e:
return Result.failed(code="E000X", msg=f"Fetch supportd models error {e}")


async def no_stream_generator(chat):
    with root_tracer.start_span("no_stream_generator"):
        msg = await chat.nostream_call()
        yield f"data: {msg}\n\n"


async def flow_stream_generator(func, incremental: bool, model_name: str):
    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
    previous_response = ""
    async for chunk in func:
        if chunk:
            msg = chunk.replace("\ufffd", "")
            if incremental:
                incremental_output = msg[len(previous_response) :]
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=incremental_output),
                )
                chunk = ChatCompletionStreamResponse(
                    id=stream_id, choices=[choice_data], model=model_name
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            else:
                # TODO generate an OpenAI-compatible streaming response
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
            previous_response = msg
            await asyncio.sleep(0.02)
    if incremental:
        yield "data: [DONE]\n\n"


async def stream_generator(chat, incremental: bool, model_name: str):
    """Generate streaming responses.

    Our goal is to generate an OpenAI-compatible streaming response.
    Currently the incremental response is compatible; the full response will be
    transformed to the same format in the future.

    Args:
        chat (BaseChat): Chat instance.
        incremental (bool): Whether to return the content incrementally or in full
            on each chunk.
        model_name (str): The model name.

    Yields:
        _type_: streaming responses
    """
    span = root_tracer.start_span("stream_generator")
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
    previous_response = ""
    async for chunk in chat.stream_call():
        if chunk:
            msg = chunk.replace("\ufffd", "")
            if incremental:
                incremental_output = msg[len(previous_response) :]
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=incremental_output),
                )
                chunk = ChatCompletionStreamResponse(
                    id=stream_id, choices=[choice_data], model=model_name
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            else:
                # TODO generate an OpenAI-compatible streaming response
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
            previous_response = msg
            await asyncio.sleep(0.02)
    if incremental:
        yield "data: [DONE]\n\n"
    span.end()


def message2Vo(message: dict, order, model_name) -> MessageVo:
    return MessageVo(
        role=message["type"],
        context=message["data"]["content"],
        order=order,
        model_name=model_name,
    )