WEB API independent
commit 8119edd5ad
parent ca6afbd445

@@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.scene.message import (
     OnceConversation,
     conversation_from_dict,
+    _conversation_to_dic,
     conversations_to_dict,
 )
 from pilot.common.formatting import MyEncoder
@@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()
         cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
-        return cursor.fetchone()
+        content = cursor.fetchone()
+        if content:
+            return cursor.fetchone()[0]
+        else:
+            return None

     def messages(self) -> List[OnceConversation]:
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         if context:
-            conversations: List[OnceConversation] = json.loads(context[0])
+            conversations: List[OnceConversation] = json.loads(context)
             return conversations
         return []

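
DuckDB's Python cursor returns each matching row as a tuple, so the JSON payload selected above lives in row[0]. A minimal sketch of that lookup, using an in-memory database and an invented conv_uid purely for illustration:

import duckdb

con = duckdb.connect()  # in-memory database, purely for illustration
con.execute("CREATE TABLE chat_history(conv_uid VARCHAR, user_name VARCHAR, messages VARCHAR)")
con.execute("INSERT INTO chat_history VALUES (?, ?, ?)", ["uid-1", "", '[{"chat_mode": "chat_normal"}]'])

cursor = con.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", ["uid-1"])
row = cursor.fetchone()                  # a tuple such as ('[{"chat_mode": ...}]',), or None if no match
messages_json = row[0] if row else None  # the stored JSON string is the first (and only) column
print(messages_json)

Note that once fetchone() has consumed the only matching row, a second fetchone() returns None, so reusing the row already bound to content (content[0]) would be the safer read inside the if content: branch.
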
@@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         conversations: List[OnceConversation] = []
         if context:
-            conversations = json.load(context)
-        conversations.append(once_message)
+            conversations = json.loads(context)
+        conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
             cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
-                           [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
+                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
         else:
             cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
+                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
         cursor.commit()
         self.connect.commit()

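
After this hunk the history round-trip is plain JSON: the conversation list is dumped into the messages column and parsed back with json.loads when it is read. A rough sketch of that cycle, using a hand-written stand-in serializer (the real _conversation_to_dic / conversations_to_dict live in pilot.scene.message and are not part of this diff):

import json

def conversation_to_dict_sketch(conv):
    # hypothetical stand-in for _conversation_to_dic; the real helper is in pilot.scene.message
    return {"chat_mode": conv["chat_mode"], "messages": conv["messages"]}

history = []                                      # conversations already stored for this conv_uid (none yet)
once_message = {"chat_mode": "chat_normal",
                "messages": [{"type": "human", "data": {"content": "hi"}}]}
history.append(conversation_to_dict_sketch(once_message))

stored = json.dumps(history, ensure_ascii=False)  # the string bound into the UPDATE/INSERT above
restored = json.loads(stored)                     # what messages() later hands back: plain dicts
assert restored[0]["chat_mode"] == "chat_normal"
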
@@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse
 from sse_starlette.sse import EventSourceResponse
 from typing import List

-from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
 from pilot.configs.config import Config
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
@@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None):
 @router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
 async def dialogue_scenes():
     scene_vos: List[ChatSceneVo] = []
-    new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
+    new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard,
+                                  ChatScene.ChatKnowledge, ChatScene.ChatExecution]
     for scene in new_modes:
         if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
             scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
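
Each ChatSceneVo above is built from an Enum member: .value supplies the wire-format chat_scene string and .name supplies the human-readable scene_name. A tiny illustration with a cut-down enum (only members whose string values appear in this diff):

from enum import Enum

class SceneSketch(Enum):  # illustrative subset, not the full ChatScene enum
    ChatNormal = "chat_normal"
    ChatDashboard = "chat_dashboard"

print(SceneSketch.ChatDashboard.value)  # "chat_dashboard" -> the chat_scene field sent to the client
print(SceneSketch.ChatDashboard.name)   # "ChatDashboard"  -> the scene_name field shown in the UI
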
@@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
 def get_db_list():
     db = CFG.local_db
     dbs = db.get_database_list()
-    params:dict = {}
+    params: dict = {}
     for name in dbs:
         params.update({name: name})
     return params
@@ -108,9 +109,9 @@ def knowledge_list():

 @router.post('/v1/chat/mode/params/list', response_model=Result[dict])
 async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
-    if ChatScene.ChatDb.value == chat_mode:
+    if ChatScene.ChatWithDbQA.value == chat_mode:
         return Result.succ(get_db_list())
-    elif ChatScene.ChatData.value == chat_mode:
+    elif ChatScene.ChatWithDbExecute.value == chat_mode:
         return Result.succ(get_db_list())
     elif ChatScene.ChatDashboard.value == chat_mode:
         return Result.succ(get_db_list())
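
chat_mode arrives from the HTTP request as a plain string, so the handler has to compare against the member's .value; comparing the Enum member itself to a string is always False, which is exactly what chat_completions did before this commit (see the next hunk). A quick illustration with a cut-down enum:

from enum import Enum

class SceneSketch(Enum):  # illustrative subset of ChatScene
    ChatDashboard = "chat_dashboard"

chat_mode = "chat_dashboard"                          # the value as it arrives in the request
print(SceneSketch.ChatDashboard == chat_mode)         # False: an Enum member never equals a str
print(SceneSketch.ChatDashboard.value == chat_mode)   # True: compare the underlying string value
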
@@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         "user_input": dialogue.user_input,
     }

-    if ChatScene.ChatDb == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatData == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatDashboard == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatExecution == dialogue.chat_mode:
-        chat_param.update("plugin_selector", dialogue.select_param)
-    elif ChatScene.ChatKnowledge == dialogue.chat_mode:
-        chat_param.update("knowledge_name", dialogue.select_param)
+    if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
+        chat_param.update({"plugin_selector": dialogue.select_param})
+    elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
+        chat_param.update({"knowledge_name": dialogue.select_param})

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
     if not chat.prompt_template.stream_out:
         return non_stream_response(chat)
     else:
-        # generator = stream_generator(chat)
-        # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
-        # return result
         return StreamingResponse(stream_generator(chat), media_type="text/plain")


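
Two fixes are folded into the hunk above: the scene check now compares .value against the incoming string (see the sketch before this hunk), and the extra parameter is merged via a mapping, because dict.update does not take a key and a value as two positional arguments. A short illustration with made-up values:

chat_param = {"user_input": "show me the sales table"}  # made-up request payload

try:
    chat_param.update("db_name", "demo_db")             # the pre-commit call style
except TypeError as err:
    print(err)                                          # "update expected at most 1 argument, got 2"

chat_param.update({"db_name": "demo_db"})               # the committed form: merge a mapping
print(chat_param)                                       # {'user_input': ..., 'db_name': 'demo_db'}
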
@@ -196,12 +194,6 @@ def stream_generator(chat):
     chat.memory.append(chat.current_message)


-# def stream_response(chat):
-#     logger.info("stream out start!")
-#     api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
-#     return api_response
-
-
 def message2Vo(message: dict, order) -> MessageVo:
     # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
     return MessageVo(role=message['type'], context=message['data']['content'], order=order)
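
message2Vo assumes each persisted message is a dict shaped like {'type': ..., 'data': {'content': ...}}. A minimal sketch of that mapping with an invented payload and a stand-in for MessageVo (its real fields are defined in api_view_model, not in this diff):

from pydantic import BaseModel

class MessageVoSketch(BaseModel):  # stand-in for MessageVo; field names assumed from the call site
    role: str
    context: str
    order: int

def message2vo_sketch(message: dict, order: int) -> MessageVoSketch:
    # mirrors the mapping above: type -> role, data.content -> context
    return MessageVoSketch(role=message["type"], context=message["data"]["content"], order=order)

vo = message2vo_sketch({"type": "human", "data": {"content": "hello"}}, order=1)
print(vo.role, vo.context, vo.order)  # human hello 1
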
@@ -18,8 +18,8 @@ class ChatScene(Enum):
     ChatNormal = "chat_normal"
     ChatDashboard = "chat_dashboard"
     ChatKnowledge = "chat_knowledge"
-    ChatDb = "chat_db"
-    ChatData= "chat_data"
+    # ChatDb = "chat_db"
+    # ChatData= "chat_data"

     @staticmethod
     def is_valid_mode(mode):
@@ -75,7 +75,7 @@ class BaseChat(ABC):
         self.prompt_template: PromptTemplate = CFG.prompt_templates[
             self.chat_mode.value
         ]
-        self.history_message: List[OnceConversation] = []
+        self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(chat_mode.value)
         self.current_tokens_used: int = 0
         ### load chat_session_id's chat historys
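
With this change a freshly constructed BaseChat starts from the persisted history instead of an empty list, so together with the DuckdbHistoryMemory fixes above, earlier turns for the same conv_uid become visible to the prompt. A rough sketch of the idea with a fake in-memory store standing in for the real memory classes:

from typing import List

class FakeMemory:  # stand-in for a BaseChatHistoryMemory implementation, for illustration only
    def __init__(self, stored: List[dict]):
        self._stored = stored

    def messages(self) -> List[dict]:  # what DuckdbHistoryMemory.messages() now returns
        return self._stored

class ChatSketch:  # only the relevant slice of BaseChat.__init__
    def __init__(self, memory: FakeMemory):
        self.memory = memory
        self.history_message: List[dict] = self.memory.messages()  # was: []

chat = ChatSketch(FakeMemory([{"chat_mode": "chat_normal", "messages": []}]))
print(len(chat.history_message))  # 1 -- the prior conversation is available to the prompt
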
@@ -95,7 +95,7 @@ class BaseChat(ABC):
     def generate_input_values(self):
         pass

-    @abstractmethod
+
     def do_action(self, prompt_response):
         return prompt_response

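
Dropping @abstractmethod turns do_action into an optional hook with a pass-through default, so scene implementations that only need the raw model output no longer have to override it. A minimal illustration (class names invented):

from abc import ABC, abstractmethod

class BaseChatSketch(ABC):  # illustrative stand-in, not the real BaseChat
    @abstractmethod
    def generate_input_values(self):
        ...

    def do_action(self, prompt_response):
        # no longer abstract: acts as a default pass-through
        return prompt_response

class EchoChat(BaseChatSketch):
    def generate_input_values(self):
        return {}

print(EchoChat().do_action("hello"))  # "hello" -- no override required
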
@@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles

-from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler

 # 加载插件
 CFG = Config()