Repository: https://github.com/csunny/DB-GPT.git (mirror, synced 2025-08-13 14:06:43 +00:00)

commit 1b49267976
parent 98a94268f0

    style:fmt
@@ -217,7 +217,9 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):


 @router.post("/v1/chat/mode/params/file/load")
-async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
+async def params_load(
+    conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
+):
     print(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
@@ -237,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file:
             )
             ## chat prepare
             dialogue = ConversationVo(
-                conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
+                conv_uid=conv_uid,
+                chat_mode=chat_mode,
+                select_param=doc_file.filename,
+                model_name=model_name,
             )
             chat: BaseChat = get_chat_instance(dialogue)
             resp = await chat.prepare()
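For context, `params_load` takes `conv_uid`, `chat_mode`, and `model_name` as query parameters and the document as a multipart upload, then prepares a chat session around the uploaded file. A minimal client sketch, assuming a locally running server (the host, port, path prefix, and sample values below are illustrative, not taken from this diff):

    import httpx  # the server side of this diff already uses httpx

    with open("report.xlsx", "rb") as f:
        resp = httpx.post(
            "http://localhost:5000/api/v1/chat/mode/params/file/load",  # hypothetical base URL
            params={
                "conv_uid": "example-conv-id",
                "chat_mode": "chat_excel",   # hypothetical mode value
                "model_name": "vicuna-13b",  # illustrative model name
            },
            files={"doc_file": f},
        )
    print(resp.json())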
@@ -264,7 +269,8 @@ def get_hist_messages(conv_uid: str):
             print(f"once:{once}")
             model_name = once.get("model_name", CFG.LLM_MODEL)
             once_message_vos = [
-                message2Vo(element, once["chat_order"], model_name) for element in once["messages"]
+                message2Vo(element, once["chat_order"], model_name)
+                for element in once["messages"]
             ]
             message_vos.extend(once_message_vos)
     return message_vos
@@ -293,9 +299,11 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
         "chat_session_id": dialogue.conv_uid,
         "current_user_input": dialogue.user_input,
         "select_param": dialogue.select_param,
-        "model_name": dialogue.model_name
+        "model_name": dialogue.model_name,
     }
-    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
+    chat: BaseChat = CHAT_FACTORY.get_implementation(
+        dialogue.chat_mode, **{"chat_param": chat_param}
+    )
     return chat


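One note on the call site kept by this hunk: `**{"chat_param": chat_param}` is just a verbose spelling of the keyword argument `chat_param=chat_param`; the dict is unpacked into keyword arguments at call time. A quick illustration:

    def show(**kwargs):
        print(kwargs)

    show(**{"chat_param": {"chat_session_id": "abc"}})  # {'chat_param': {'chat_session_id': 'abc'}}
    show(chat_param={"chat_session_id": "abc"})         # identical output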
@@ -314,7 +322,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
     # dialogue.model_name = CFG.LLM_MODEL
-    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
+    print(
+        f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
+    )
     chat: BaseChat = get_chat_instance(dialogue)
     # background_tasks = BackgroundTasks()
     # background_tasks.add_task(release_model_semaphore)
@@ -344,6 +354,7 @@ async def model_types():
     print(f"/controller/model/types")
     try:
         import httpx
+
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
@@ -387,5 +398,8 @@ async def stream_generator(chat):

 def message2Vo(message: dict, order, model_name) -> MessageVo:
     return MessageVo(
-        role=message["type"], context=message["data"]["content"], order=order, model_name=model_name
+        role=message["type"],
+        context=message["data"]["content"],
+        order=order,
+        model_name=model_name,
     )
@@ -32,13 +32,13 @@ class BaseChat(ABC):

         arbitrary_types_allowed = True

-    def __init__(
-        self, chat_param: Dict
-    ):
+    def __init__(self, chat_param: Dict):
         self.chat_session_id = chat_param["chat_session_id"]
         self.chat_mode = chat_param["chat_mode"]
         self.current_user_input: str = chat_param["current_user_input"]
-        self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
+        self.llm_model = (
+            chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
+        )
         self.llm_echo = False

         ### load prompt template
@@ -58,7 +58,9 @@ class BaseChat(ABC):
         self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])

         self.history_message: List[OnceConversation] = self.memory.messages()
-        self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
+        self.current_message: OnceConversation = OnceConversation(
+            self.chat_mode.value()
+        )
         self.current_message.model_name = self.llm_model
         if chat_param["select_param"]:
             if len(self.chat_mode.param_types()) > 0:
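Taken together, the two `BaseChat` hunks document the shape of the `chat_param` dict the constructor consumes: it reads `chat_session_id`, `chat_mode`, `current_user_input`, `select_param`, and `model_name`, falling back to `CFG.LLM_MODEL` when `model_name` is falsy. A representative value, with illustrative field contents:

    chat_param = {
        "chat_session_id": "example-conv-id",
        "chat_mode": ChatScene.ChatWithDbQA,  # normally overwritten by the subclass
        "current_user_input": "How many users signed up last week?",
        "select_param": "example_db",         # a db name, file name, or plugin selector
        "model_name": "vicuna-13b",           # falsy values fall back to CFG.LLM_MODEL
    }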
@@ -21,16 +21,11 @@ class ChatDashboard(BaseChat):
     report_name: str
     """Number of results to return from the query"""

-    def __init__(
-        self,
-        chat_param: Dict
-    ):
+    def __init__(self, chat_param: Dict):
         """ """
         self.db_name = chat_param["select_param"]
         chat_param["chat_mode"] = ChatScene.ChatDashboard
-        super().__init__(
-            chat_param=chat_param
-        )
+        super().__init__(chat_param=chat_param)
         if not self.db_name:
             raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
         self.db_name = self.db_name
@@ -37,9 +37,7 @@ class ChatExcel(BaseChat):
             )
         )

-        super().__init__(
-            chat_param=chat_param
-        )
+        super().__init__(chat_param=chat_param)

     def _generate_command_string(self, command: Dict[str, Any]) -> str:
         """
@@ -43,9 +43,7 @@ class ExcelLearning(BaseChat):
             "select_param": select_param,
             "model_name": model_name,
         }
-        super().__init__(
-            chat_param=chat_param
-        )
+        super().__init__(chat_param=chat_param)
         if parent_mode:
             self.current_message.chat_mode = parent_mode.value()

@@ -18,9 +18,7 @@ class ChatWithDbQA(BaseChat):
         """ """
         self.db_name = chat_param["select_param"]
         chat_param["chat_mode"] = ChatScene.ChatWithDbQA
-        super().__init__(
-            chat_param=chat_param
-        )
+        super().__init__(chat_param=chat_param)

         if self.db_name:
             self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
@@ -18,9 +18,7 @@ class ChatWithPlugin(BaseChat):
     def __init__(self, chat_param: Dict):
         self.plugin_selector = chat_param.select_param
         chat_param["chat_mode"] = ChatScene.ChatExecution
-        super().__init__(
-            chat_param=chat_param
-        )
+        super().__init__(chat_param=chat_param)
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry
         # Load the commands available from plugins
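Every subclass hunk above makes the same mechanical change: a multi-line `super().__init__(chat_param=chat_param)` call collapses to one line now that it fits the limit. The shared pattern, as a minimal sketch (the `ChatExample` name is hypothetical; the scene value reuses one from this diff for illustration):

    from typing import Dict

    class ChatExample(BaseChat):  # illustrative subclass, not in the repository
        def __init__(self, chat_param: Dict):
            # stash the scene-specific parameter, then tag the scene
            self.db_name = chat_param["select_param"]
            chat_param["chat_mode"] = ChatScene.ChatDashboard
            super().__init__(chat_param=chat_param)

One pre-existing oddity the formatter leaves untouched: `ChatWithPlugin` reads `chat_param.select_param` with attribute access even though `chat_param` is a `Dict`; indexing (`chat_param["select_param"]`), as in the other subclasses, is presumably what was intended.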