mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 10:54:29 +00:00
WEB API independent
This commit is contained in:
parent
ca6afbd445
commit
8119edd5ad
@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pilot.scene.message import (
|
||||
OnceConversation,
|
||||
conversation_from_dict,
|
||||
_conversation_to_dic,
|
||||
conversations_to_dict,
|
||||
)
|
||||
from pilot.common.formatting import MyEncoder
|
||||
@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
def __get_messages_by_conv_uid(self, conv_uid: str):
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
|
||||
return cursor.fetchone()
|
||||
|
||||
content = cursor.fetchone()
|
||||
if content:
|
||||
return cursor.fetchone()[0]
|
||||
else:
|
||||
return None
|
||||
def messages(self) -> List[OnceConversation]:
|
||||
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
||||
if context:
|
||||
conversations: List[OnceConversation] = json.loads(context[0])
|
||||
conversations: List[OnceConversation] = json.loads(context)
|
||||
return conversations
|
||||
return []
|
||||
|
||||
@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
||||
conversations: List[OnceConversation] = []
|
||||
if context:
|
||||
conversations = json.load(context)
|
||||
conversations.append(once_message)
|
||||
conversations = json.loads(context)
|
||||
conversations.append(_conversation_to_dic(once_message))
|
||||
cursor = self.connect.cursor()
|
||||
if context:
|
||||
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
|
||||
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
|
||||
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
|
||||
else:
|
||||
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
||||
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
|
||||
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
|
||||
|
@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from typing import List
|
||||
|
||||
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None):
|
||||
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
|
||||
async def dialogue_scenes():
|
||||
scene_vos: List[ChatSceneVo] = []
|
||||
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
|
||||
new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard,
|
||||
ChatScene.ChatKnowledge, ChatScene.ChatExecution]
|
||||
for scene in new_modes:
|
||||
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
|
||||
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
|
||||
@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
|
||||
def get_db_list():
|
||||
db = CFG.local_db
|
||||
dbs = db.get_database_list()
|
||||
params:dict = {}
|
||||
params: dict = {}
|
||||
for name in dbs:
|
||||
params.update({name: name})
|
||||
return params
|
||||
@ -108,9 +109,9 @@ def knowledge_list():
|
||||
|
||||
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
||||
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||
if ChatScene.ChatDb.value == chat_mode:
|
||||
if ChatScene.ChatWithDbQA.value == chat_mode:
|
||||
return Result.succ(get_db_list())
|
||||
elif ChatScene.ChatData.value == chat_mode:
|
||||
elif ChatScene.ChatWithDbExecute.value == chat_mode:
|
||||
return Result.succ(get_db_list())
|
||||
elif ChatScene.ChatDashboard.value == chat_mode:
|
||||
return Result.succ(get_db_list())
|
||||
@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
"user_input": dialogue.user_input,
|
||||
}
|
||||
|
||||
if ChatScene.ChatDb == dialogue.chat_mode:
|
||||
chat_param.update("db_name", dialogue.select_param)
|
||||
elif ChatScene.ChatData == dialogue.chat_mode:
|
||||
chat_param.update("db_name", dialogue.select_param)
|
||||
elif ChatScene.ChatDashboard == dialogue.chat_mode:
|
||||
chat_param.update("db_name", dialogue.select_param)
|
||||
elif ChatScene.ChatExecution == dialogue.chat_mode:
|
||||
chat_param.update("plugin_selector", dialogue.select_param)
|
||||
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
|
||||
chat_param.update("knowledge_name", dialogue.select_param)
|
||||
if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
|
||||
chat_param.update({"db_name": dialogue.select_param})
|
||||
elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
|
||||
chat_param.update({"db_name": dialogue.select_param})
|
||||
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
|
||||
chat_param.update({"db_name": dialogue.select_param})
|
||||
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
|
||||
chat_param.update({"plugin_selector": dialogue.select_param})
|
||||
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
|
||||
chat_param.update({"knowledge_name": dialogue.select_param})
|
||||
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
||||
if not chat.prompt_template.stream_out:
|
||||
return non_stream_response(chat)
|
||||
else:
|
||||
# generator = stream_generator(chat)
|
||||
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
|
||||
# return result
|
||||
return StreamingResponse(stream_generator(chat), media_type="text/plain")
|
||||
|
||||
|
||||
@ -196,12 +194,6 @@ def stream_generator(chat):
|
||||
chat.memory.append(chat.current_message)
|
||||
|
||||
|
||||
# def stream_response(chat):
|
||||
# logger.info("stream out start!")
|
||||
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
|
||||
# return api_response
|
||||
|
||||
|
||||
def message2Vo(message: dict, order) -> MessageVo:
|
||||
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
|
||||
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
|
||||
|
@ -18,8 +18,8 @@ class ChatScene(Enum):
|
||||
ChatNormal = "chat_normal"
|
||||
ChatDashboard = "chat_dashboard"
|
||||
ChatKnowledge = "chat_knowledge"
|
||||
ChatDb = "chat_db"
|
||||
ChatData= "chat_data"
|
||||
# ChatDb = "chat_db"
|
||||
# ChatData= "chat_data"
|
||||
|
||||
@staticmethod
|
||||
def is_valid_mode(mode):
|
||||
|
@ -75,7 +75,7 @@ class BaseChat(ABC):
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||
self.chat_mode.value
|
||||
]
|
||||
self.history_message: List[OnceConversation] = []
|
||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value)
|
||||
self.current_tokens_used: int = 0
|
||||
### load chat_session_id's chat historys
|
||||
@ -95,7 +95,7 @@ class BaseChat(ABC):
|
||||
def generate_input_values(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
return prompt_response
|
||||
|
||||
|
@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
||||
|
||||
# 加载插件
|
||||
CFG = Config()
|
||||
|
Loading…
Reference in New Issue
Block a user