WEB API independent

Author: tuyang.yhj
Date:   2023-06-29 09:55:43 +08:00
parent caa1a41065
commit 8e93833321
10 changed files with 157 additions and 82 deletions


@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
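Note: the BackgroundTasks import added above is what lets the endpoint release its concurrency slot only after a streamed response has been fully sent, since Starlette runs a response's background tasks once the body finishes. A minimal, self-contained sketch of that pattern (the app, route, and slot_semaphore names are illustrative, not from this commit):

import asyncio
from fastapi import FastAPI, BackgroundTasks
from fastapi.responses import StreamingResponse

app = FastAPI()
slot_semaphore = asyncio.Semaphore(5)  # hypothetical concurrency cap

def release_slot():
    # Runs only after the streamed body has been fully sent.
    slot_semaphore.release()

async def ticker():
    for i in range(3):
        yield f"data: tick {i}\n\n"
        await asyncio.sleep(0.1)

@app.get("/ticks")
async def ticks():
    await slot_semaphore.acquire()
    tasks = BackgroundTasks()
    tasks.add_task(release_slot)  # cleanup tied to the response lifecycle
    return StreamingResponse(ticker(), media_type="text/event-stream", background=tasks)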
@@ -35,6 +35,7 @@ knowledge_service = KnowledgeService()
 model_semaphore = None
 global_counter = 0
+

 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
     for error in exc.errors():
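Note: this hunk shows only the top of validation_exception_handler; the loop goes on to join each error's location and message. A plausible completion of the pattern for reference (the 400 status and the response shape here are assumptions, not this project's Result wrapper):

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

async def validation_exception_handler(request: Request, exc: RequestValidationError):
    # Flatten FastAPI's validation errors into a single readable message.
    message = ""
    for error in exc.errors():
        loc = ".".join(str(part) for part in error["loc"])
        message += f"{loc}: {error['msg']}; "
    return JSONResponse(status_code=400, content={"success": False, "err_msg": message})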
@@ -102,7 +103,7 @@ async def dialogue_scenes():
 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -176,6 +177,11 @@ async def dialogue_history_messages(con_uid: str):
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    if not dialogue.chat_mode:
+        dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
     global model_semaphore, global_counter
     global_counter += 1
     if model_semaphore is None:
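Note: the defaults added above make the endpoint callable without a prior /v1/chat/dialogue/new round trip, and the lines that follow throttle concurrent model calls through the lazily created module-level semaphore from the earlier hunk. The same steps, pulled out into a standalone sketch (acquire_model_slot is a hypothetical helper; in this file the logic is inlined in chat_completions, with the limit read from project config rather than the constant used here):

import asyncio

model_semaphore = None
global_counter = 0

async def acquire_model_slot(limit: int = 5):
    # Lazily create the semaphore on first use, then wait for a free slot.
    global model_semaphore, global_counter
    global_counter += 1  # running count of chat requests seen
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(limit)
    await model_semaphore.acquire()

def release_model_semaphore():
    model_semaphore.release()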
@@ -204,32 +210,53 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         chat_param.update({"knowledge_space": dialogue.select_param})

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }
     if not chat.prompt_template.stream_out:
-        return chat.nostream_call()
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
     else:
-        background_tasks = BackgroundTasks()
-        background_tasks.add_task(release_model_semaphore)
-        return StreamingResponse(stream_generator(chat), background=background_tasks)
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)


 def release_model_semaphore():
     model_semaphore.release()


-def stream_generator(chat):
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"
+
+
+async def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-            chat.current_message.add_ai_message(msg)
-            yield msg
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
     chat.memory.append(chat.current_message)


 def message2Vo(message: dict, order) -> MessageVo:
-    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
+    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
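Note on the framing above: both generators escape raw newlines ("\n" becomes "\\n") and emit SSE-style "data: ..." frames, each terminated by a blank line, so a client must split on frame boundaries and undo the escaping. A minimal consumer sketch (the URL, port, and every payload field except conv_uid and chat_mode are assumptions):

import requests

payload = {
    "conv_uid": "some-conversation-id",  # e.g. from /v1/chat/dialogue/new
    "chat_mode": "chat_normal",          # assumed value of ChatScene.ChatNormal.value
    "user_input": "Hello",               # assumed ConversationVo field name
}
resp = requests.post("http://localhost:5000/v1/chat/completions", json=payload, stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        msg = line[len("data:"):].strip().replace("\\n", "\n")  # undo server-side escaping
        print(msg)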