WEB API independent

This commit is contained in:
tuyang.yhj 2023-06-27 17:45:07 +08:00
parent ca6afbd445
commit 8119edd5ad
5 changed files with 32 additions and 36 deletions

View File

@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import ( from pilot.scene.message import (
OnceConversation, OnceConversation,
conversation_from_dict, conversation_from_dict,
_conversation_to_dic,
conversations_to_dict, conversations_to_dict,
) )
from pilot.common.formatting import MyEncoder from pilot.common.formatting import MyEncoder
@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __get_messages_by_conv_uid(self, conv_uid: str): def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
return cursor.fetchone() content = cursor.fetchone()
if content:
return cursor.fetchone()[0]
else:
return None
def messages(self) -> List[OnceConversation]: def messages(self) -> List[OnceConversation]:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id) context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
if context: if context:
conversations: List[OnceConversation] = json.loads(context[0]) conversations: List[OnceConversation] = json.loads(context)
return conversations return conversations
return [] return []
@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
context = self.__get_messages_by_conv_uid(self.chat_seesion_id) context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
conversations: List[OnceConversation] = [] conversations: List[OnceConversation] = []
if context: if context:
conversations = json.load(context) conversations = json.loads(context)
conversations.append(once_message) conversations.append(_conversation_to_dic(once_message))
cursor = self.connect.cursor() cursor = self.connect.cursor()
if context: if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id]) [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
else: else:
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)]) [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
cursor.commit() cursor.commit()
self.connect.commit() self.connect.commit()

View File

@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from typing import List from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None):
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) @router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes(): async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = [] scene_vos: List[ChatSceneVo] = []
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard,
ChatScene.ChatKnowledge, ChatScene.ChatExecution]
for scene in new_modes: for scene in new_modes:
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list(): def get_db_list():
db = CFG.local_db db = CFG.local_db
dbs = db.get_database_list() dbs = db.get_database_list()
params:dict = {} params: dict = {}
for name in dbs: for name in dbs:
params.update({name: name}) params.update({name: name})
return params return params
@ -108,9 +109,9 @@ def knowledge_list():
@router.post('/v1/chat/mode/params/list', response_model=Result[dict]) @router.post('/v1/chat/mode/params/list', response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value): async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode: if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list()) return Result.succ(get_db_list())
elif ChatScene.ChatData.value == chat_mode: elif ChatScene.ChatWithDbExecute.value == chat_mode:
return Result.succ(get_db_list()) return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value == chat_mode: elif ChatScene.ChatDashboard.value == chat_mode:
return Result.succ(get_db_list()) return Result.succ(get_db_list())
@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"user_input": dialogue.user_input, "user_input": dialogue.user_input,
} }
if ChatScene.ChatDb == dialogue.chat_mode: if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param) chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatData == dialogue.chat_mode: elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param) chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard == dialogue.chat_mode: elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param) chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatExecution == dialogue.chat_mode: elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update("plugin_selector", dialogue.select_param) chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge == dialogue.chat_mode: elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
chat_param.update("knowledge_name", dialogue.select_param) chat_param.update({"knowledge_name": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out: if not chat.prompt_template.stream_out:
return non_stream_response(chat) return non_stream_response(chat)
else: else:
# generator = stream_generator(chat)
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
# return result
return StreamingResponse(stream_generator(chat), media_type="text/plain") return StreamingResponse(stream_generator(chat), media_type="text/plain")
@ -196,12 +194,6 @@ def stream_generator(chat):
chat.memory.append(chat.current_message) chat.memory.append(chat.current_message)
# def stream_response(chat):
# logger.info("stream out start!")
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
# return api_response
def message2Vo(message: dict, order) -> MessageVo: def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
return MessageVo(role=message['type'], context=message['data']['content'], order=order) return MessageVo(role=message['type'], context=message['data']['content'], order=order)

View File

@ -18,8 +18,8 @@ class ChatScene(Enum):
ChatNormal = "chat_normal" ChatNormal = "chat_normal"
ChatDashboard = "chat_dashboard" ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge" ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db" # ChatDb = "chat_db"
ChatData= "chat_data" # ChatData= "chat_data"
@staticmethod @staticmethod
def is_valid_mode(mode): def is_valid_mode(mode):

View File

@ -75,7 +75,7 @@ class BaseChat(ABC):
self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value self.chat_mode.value
] ]
self.history_message: List[OnceConversation] = [] self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_message: OnceConversation = OnceConversation(chat_mode.value)
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0
### load chat_session_id's chat historys ### load chat_session_id's chat historys
@ -95,7 +95,7 @@ class BaseChat(ABC):
def generate_input_values(self): def generate_input_values(self):
pass pass
@abstractmethod
def do_action(self, prompt_response): def do_action(self, prompt_response):
return prompt_response return prompt_response

View File

@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
# 加载插件 # 加载插件
CFG = Config() CFG = Config()