Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

This commit is contained in:
aries_ckt
2023-06-29 13:35:03 +08:00
21 changed files with 375 additions and 315 deletions

View File

@@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@@ -12,12 +12,7 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.server.api_v1.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -37,6 +32,9 @@ CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
async def validation_exception_handler(request: Request, exc: RequestValidationError):
message = ""
@@ -76,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None):
)
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues)
return Result[ConversationVo].succ(dialogues[-10:][::-1])
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes: List[ChatScene] = [
ChatScene.ChatDb,
ChatScene.ChatData,
ChatScene.ChatWithDbExecute,
ChatScene.ChatWithDbQA,
ChatScene.ChatDashboard,
ChatScene.ChatKnowledge,
ChatScene.ChatExecution,
@@ -105,7 +103,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -139,9 +137,9 @@ def knowledge_list():
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode:
if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatData.value == chat_mode:
elif ChatScene.ChatWithDbExecute.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value == chat_mode:
return Result.succ(get_db_list())
@@ -179,6 +177,16 @@ async def dialogue_history_messages(con_uid: str):
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value
if not dialogue.conv_uid:
dialogue.conv_uid = str(uuid.uuid1())
global model_semaphore, global_counter
global_counter += 1
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(
@@ -190,99 +198,65 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"user_input": dialogue.user_input,
}
if ChatScene.ChatDb == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatData == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatDashboard == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatExecution == dialogue.chat_mode:
chat_param.update("plugin_selector", dialogue.select_param)
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
chat_param.update("knowledge_space", dialogue.select_param)
if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
chat_param.update({"knowledge_space": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
headers = {
# "Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
# "Transfer-Encoding": "chunked",
}
if not chat.prompt_template.stream_out:
return non_stream_response(chat)
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
background=background_tasks)
else:
# generator = stream_generator(chat)
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
# return result
return StreamingResponse(stream_generator(chat), media_type="text/plain")
return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
background=background_tasks)
def stream_test():
for message in ["Hello", "world", "how", "are", "you"]:
yield message
# yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
def release_model_semaphore():
model_semaphore.release()
def stream_generator(chat):
async def no_stream_generator(chat):
msg = chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
yield msg
# chat.current_message.add_ai_message(msg)
# vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
# json_text = json.dumps(vo.__dict__)
# yield json_text.encode('utf-8')
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
chat.memory.append(chat.current_message)
# def stream_response(chat):
# logger.info("stream out start!")
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
# return api_response
def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
return MessageVo(
role=message["type"], context=message["data"]["content"], order=order
)
def non_stream_response(chat):
logger.info("not stream out, wait model response!")
return chat.nostream_call()
@router.get("/v1/db/types", response_model=Result[str])
async def db_types():
return Result.succ(["mysql", "duckdb"])
@router.get("/v1/db/list", response_model=Result[str])
async def db_list():
db = CFG.local_db
dbs = db.get_database_list()
return Result.succ(dbs)
@router.get("/v1/knowledge/list")
async def knowledge_list():
return ["test1", "test2"]
@router.post("/v1/knowledge/add")
async def knowledge_add():
return ["test1", "test2"]
@router.post("/v1/knowledge/delete")
async def knowledge_delete():
return ["test1", "test2"]
@router.get("/v1/knowledge/types")
async def knowledge_types():
return ["test1", "test2"]
@router.get("/v1/knowledge/detail")
async def knowledge_detail():
return ["test1", "test2"]
return MessageVo(role=message['type'], context=message['data']['content'], order=order)