Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-17 15:10:14 +00:00)

Commit: Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks

 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -12,12 +12,7 @@ from fastapi.responses import JSONResponse
 from sse_starlette.sse import EventSourceResponse
 from typing import List

-from pilot.server.api_v1.api_view_model import (
-    Result,
-    ConversationVo,
-    MessageVo,
-    ChatSceneVo,
-)
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
 from pilot.configs.config import Config
 from pilot.openapi.knowledge.knowledge_service import KnowledgeService
 from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -37,6 +32,9 @@ CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
 knowledge_service = KnowledgeService()

+model_semaphore = None
+global_counter = 0
+

 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
@@ -76,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None):
         )
         dialogues.append(conv_vo)

-    return Result[ConversationVo].succ(dialogues)
+    return Result[ConversationVo].succ(dialogues[-10:][::-1])


 @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
 async def dialogue_scenes():
     scene_vos: List[ChatSceneVo] = []
     new_modes: List[ChatScene] = [
-        ChatScene.ChatDb,
-        ChatScene.ChatData,
+        ChatScene.ChatWithDbExecute,
+        ChatScene.ChatWithDbQA,
         ChatScene.ChatDashboard,
         ChatScene.ChatKnowledge,
         ChatScene.ChatExecution,
@@ -105,7 +103,7 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -139,9 +137,9 @@ def knowledge_list():

 @router.post("/v1/chat/mode/params/list", response_model=Result[dict])
 async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
-    if ChatScene.ChatDb.value == chat_mode:
+    if ChatScene.ChatWithDbQA.value == chat_mode:
         return Result.succ(get_db_list())
-    elif ChatScene.ChatData.value == chat_mode:
+    elif ChatScene.ChatWithDbExecute.value == chat_mode:
         return Result.succ(get_db_list())
     elif ChatScene.ChatDashboard.value == chat_mode:
         return Result.succ(get_db_list())
@@ -179,6 +177,16 @@ async def dialogue_history_messages(con_uid: str):
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
+    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
     if not dialogue.chat_mode:
         dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(
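
The hunk above introduces per-request concurrency limiting; the matching release side appears in the next hunk, where a BackgroundTasks callback frees the semaphore once the streamed response has finished sending. The sketch below is a minimal, self-contained illustration of that pattern in a generic FastAPI app; MAX_CONCURRENCY, fake_stream, and the /demo/completions route are illustrative stand-ins, not DB-GPT code.

    # Sketch: cap in-flight model calls with a lazily created asyncio.Semaphore,
    # releasing the permit only after the response body has been fully streamed.
    import asyncio

    from fastapi import BackgroundTasks, FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()
    MAX_CONCURRENCY = 5     # stand-in for CFG.LIMIT_MODEL_CONCURRENCY
    model_semaphore = None  # created on first request, as in the diff


    def release_model_semaphore():
        model_semaphore.release()


    async def fake_stream():
        # stand-in for the real model output stream
        for word in ["hello", "world"]:
            yield f"data:{word}\n\n"
            await asyncio.sleep(0.1)


    @app.post("/demo/completions")
    async def completions():
        global model_semaphore
        if model_semaphore is None:
            model_semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
        await model_semaphore.acquire()

        # BackgroundTasks run after the response finishes, so the permit
        # stays held for the whole lifetime of the stream.
        background_tasks = BackgroundTasks()
        background_tasks.add_task(release_model_semaphore)
        return StreamingResponse(fake_stream(), media_type="text/plain", background=background_tasks)

The permit is deliberately released by the response's background task rather than inside the handler, because the handler returns before the stream is actually consumed by the client.
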
@@ -190,99 +198,65 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         "user_input": dialogue.user_input,
     }

-    if ChatScene.ChatDb == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatData == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatDashboard == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatExecution == dialogue.chat_mode:
-        chat_param.update("plugin_selector", dialogue.select_param)
-    elif ChatScene.ChatKnowledge == dialogue.chat_mode:
-        chat_param.update("knowledge_space", dialogue.select_param)
+    if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
+        chat_param.update({"plugin_selector": dialogue.select_param})
+    elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
+        chat_param.update({"knowledge_space": dialogue.select_param})

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }

     if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
     else:
-        # generator = stream_generator(chat)
-        # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
-        # return result
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)


-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+def release_model_semaphore():
+    model_semaphore.release()


-def stream_generator(chat):
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"
+
+
+async def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                chunk, chat.skip_echo_len
-            )
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
     chat.memory.append(chat.current_message)


 # def stream_response(chat):
 #     logger.info("stream out start!")
 #     api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
 #     return api_response


 def message2Vo(message: dict, order) -> MessageVo:
     # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
-    return MessageVo(
-        role=message["type"], context=message["data"]["content"], order=order
-    )
-
-
-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get("/v1/db/types", response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get("/v1/db/list", response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get("/v1/knowledge/list")
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/add")
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/delete")
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/types")
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/detail")
-async def knowledge_detail():
-    return ["test1", "test2"]
+    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
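
Both generators added above frame each chunk as "data:<payload>" terminated by a blank line, the server-sent-events convention, and escape embedded newlines first so a multi-line payload cannot be mistaken for an event boundary. The round-trip sketch below illustrates only that framing; frame_event and parse_event are hypothetical helpers (Python 3.9+ for removeprefix), not part of the DB-GPT codebase.

    def frame_event(msg: str) -> str:
        # escape real newlines so the payload stays on a single "data:" line
        return "data:" + msg.replace("\n", "\\n") + "\n\n"


    def parse_event(raw: str) -> str:
        # strip the "data:" prefix and blank-line terminator, then unescape
        body = raw.removeprefix("data:").rstrip("\n")
        return body.replace("\\n", "\n")


    if __name__ == "__main__":
        original = "line one\nline two"
        framed = frame_event(original)
        assert parse_event(framed) == original
        print(repr(framed))  # 'data:line one\\nline two\n\n'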