WEB API independent

This commit is contained in:
tuyang.yhj 2023-06-27 17:45:07 +08:00
parent ca6afbd445
commit 8119edd5ad
5 changed files with 32 additions and 36 deletions

View File

@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
_conversation_to_dic,
conversations_to_dict,
)
from pilot.common.formatting import MyEncoder
@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
return cursor.fetchone()
content = cursor.fetchone()
if content:
return cursor.fetchone()[0]
else:
return None
def messages(self) -> List[OnceConversation]:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
if context:
conversations: List[OnceConversation] = json.loads(context[0])
conversations: List[OnceConversation] = json.loads(context)
return conversations
return []
@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
conversations: List[OnceConversation] = []
if context:
conversations = json.load(context)
conversations.append(once_message)
conversations = json.loads(context)
conversations.append(_conversation_to_dic(once_message))
cursor = self.connect.cursor()
if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
else:
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
cursor.commit()
self.connect.commit()

View File

@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None):
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard,
ChatScene.ChatKnowledge, ChatScene.ChatExecution]
for scene in new_modes:
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params:dict = {}
params: dict = {}
for name in dbs:
params.update({name: name})
return params
@ -108,9 +109,9 @@ def knowledge_list():
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode:
if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatData.value == chat_mode:
elif ChatScene.ChatWithDbExecute.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value == chat_mode:
return Result.succ(get_db_list())
@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"user_input": dialogue.user_input,
}
if ChatScene.ChatDb == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatData == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatDashboard == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatExecution == dialogue.chat_mode:
chat_param.update("plugin_selector", dialogue.select_param)
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
chat_param.update("knowledge_name", dialogue.select_param)
if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
chat_param.update({"knowledge_name": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out:
return non_stream_response(chat)
else:
# generator = stream_generator(chat)
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
# return result
return StreamingResponse(stream_generator(chat), media_type="text/plain")
@ -196,12 +194,6 @@ def stream_generator(chat):
chat.memory.append(chat.current_message)
# def stream_response(chat):
# logger.info("stream out start!")
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
# return api_response
def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
return MessageVo(role=message['type'], context=message['data']['content'], order=order)

View File

@ -18,8 +18,8 @@ class ChatScene(Enum):
ChatNormal = "chat_normal"
ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db"
ChatData= "chat_data"
# ChatDb = "chat_db"
# ChatData= "chat_data"
@staticmethod
def is_valid_mode(mode):

View File

@ -75,7 +75,7 @@ class BaseChat(ABC):
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value
]
self.history_message: List[OnceConversation] = []
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value)
self.current_tokens_used: int = 0
### load chat_session_id's chat historys
@ -95,7 +95,7 @@ class BaseChat(ABC):
def generate_input_values(self):
pass
@abstractmethod
def do_action(self, prompt_response):
return prompt_response

View File

@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
# 加载插件
CFG = Config()