WEB API independent

This commit is contained in:
tuyang.yhj
2023-06-29 15:33:14 +08:00
parent 6a3bf33a24
commit 7208dd6c88
10 changed files with 161 additions and 73 deletions

View File

@@ -51,6 +51,39 @@ def __get_conv_user_message(conversations: dict):
return ""
def __new_conversation(chat_mode, user_id) -> ConversationVo:
unique_id = uuid.uuid1()
history_mem = DuckdbHistoryMemory(str(unique_id))
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params: dict = {}
for name in dbs:
params.update({name: name})
return params
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
@@ -63,14 +96,15 @@ async def dialogue_list(response: Response, user_id: str = None):
for item in datas:
conv_uid = item.get("conv_uid")
messages = item.get("messages")
conversations = json.loads(messages)
summary = item.get("summary")
chat_mode = item.get("chat_mode")
first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv["chat_mode"],
user_input=summary,
chat_mode=chat_mode,
)
dialogues.append(conv_vo)
@@ -101,39 +135,13 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params: dict = {}
for name in dbs:
params.update({name: name})
return params
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
@@ -180,7 +188,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value
if not dialogue.conv_uid:
dialogue.conv_uid = str(uuid.uuid1())
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
dialogue.conv_uid = conv_vo.conv_uid
global model_semaphore, global_counter
global_counter += 1
@@ -236,6 +245,7 @@ async def no_stream_generator(chat):
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
if not CFG.NEW_SERVER_MODE: