Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

This commit is contained in:
aries_ckt
2023-06-30 10:43:35 +08:00
20 changed files with 79 additions and 42 deletions

View File

@@ -2,6 +2,7 @@ import uuid
import json
import asyncio
import time
import os
from fastapi import (
APIRouter,
Request,
@@ -12,11 +13,11 @@ from fastapi import (
BackgroundTasks,
)
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.openapi.api_v1.api_view_model import (
@@ -46,6 +47,7 @@ knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
static_file_path = os.path.join(os.getcwd(), "server/static")
async def validation_exception_handler(request: Request, exc: RequestValidationError):
@@ -95,6 +97,10 @@ def knowledge_list():
return params
@router.get("/")
async def read_main():
    """Serve the bundled static test page at the application root ("/")."""
    # static_file_path is the module-level "<cwd>/server/static" directory.
    index_page = f"{static_file_path}/test.html"
    return FileResponse(index_page)
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
@@ -111,8 +117,6 @@ async def dialogue_list(response: Response, user_id: str = None):
summary = item.get("summary")
chat_mode = item.get("chat_mode")
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=summary,
@@ -147,7 +151,6 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
@@ -155,6 +158,7 @@ async def dialogue_new(
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatWithDbQA.value == chat_mode:
@@ -274,15 +278,15 @@ async def stream_generator(chat):
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"

View File

@@ -124,6 +124,7 @@ class KnowledgeDocumentDao:
updated_space = session.merge(document)
session.commit()
return updated_space.id
#
# def delete_knowledge_document(self, document_id: int):
# cursor = self.conn.cursor()

View File

@@ -114,8 +114,13 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
):
raise Exception(
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
)
client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),