Files
DB-GPT/pilot/openapi/api_v1/api_v1.py
2023-06-29 20:41:05 +08:00

277 lines
9.5 KiB
Python

import uuid
import json
import asyncio
import time
import os
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
static_file_path = os.path.join(os.getcwd(), "server/static")
async def validation_exception_handler(request: Request, exc: RequestValidationError):
message = ""
for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(msg=message)
def __get_conv_user_message(conversations: dict):
messages = conversations["messages"]
for item in messages:
if item["type"] == "human":
return item["data"]["content"]
return ""
def __new_conversation(chat_mode, user_id) -> ConversationVo:
unique_id = uuid.uuid1()
history_mem = DuckdbHistoryMemory(str(unique_id))
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params: dict = {}
for name in dbs:
params.update({name: name})
return params
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
@router.get("/")
async def read_main():
return FileResponse(f"{static_file_path}/test.html")
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET"
response.headers["Access-Control-Request-Headers"] = "content-type"
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
for item in datas:
conv_uid = item.get("conv_uid")
summary = item.get("summary")
chat_mode = item.get("chat_mode")
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=summary,
chat_mode=chat_mode,
)
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues[-10:][::-1])
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes: List[ChatScene] = [
ChatScene.ChatWithDbExecute,
ChatScene.ChatWithDbQA,
ChatScene.ChatDashboard,
ChatScene.ChatKnowledge,
ChatScene.ChatExecution,
]
for scene in new_modes:
if not scene.value in [
ChatScene.ChatNormal.value,
ChatScene.InnerChatDBSummary.value,
]:
scene_vo = ChatSceneVo(
chat_scene=scene.value,
scene_name=scene.name,
param_title="Selection Param",
)
scene_vos.append(scene_vo)
return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatWithDbExecute.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatExecution.value == chat_mode:
return Result.succ(plugins_select_info())
elif ChatScene.ChatKnowledge.value == chat_mode:
return Result.succ(knowledge_list())
else:
return Result.succ(None)
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
history_mem.delete()
return Result.succ(None)
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
once_message_vos = [
message2Vo(element, once["chat_order"]) for element in once["messages"]
]
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value
if not dialogue.conv_uid:
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
dialogue.conv_uid = conv_vo.conv_uid
global model_semaphore, global_counter
global_counter += 1
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
)
chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input,
}
if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
chat_param.update({"knowledge_space": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
headers = {
# "Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
# "Transfer-Encoding": "chunked",
}
if not chat.prompt_template.stream_out:
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
background=background_tasks)
else:
return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
background=background_tasks)
def release_model_semaphore():
model_semaphore.release()
async def no_stream_generator(chat):
msg = chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)
def message2Vo(message: dict, order) -> MessageVo:
return MessageVo(role=message['type'], context=message['data']['content'], order=order)