mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-12 05:32:32 +00:00
style:fmt
This commit is contained in:
parent
98a94268f0
commit
1b49267976
@ -217,7 +217,9 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
|
||||
|
||||
|
||||
@router.post("/v1/chat/mode/params/file/load")
|
||||
async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
|
||||
async def params_load(
|
||||
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
|
||||
):
|
||||
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
@ -237,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file:
|
||||
)
|
||||
## chat prepare
|
||||
dialogue = ConversationVo(
|
||||
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
|
||||
conv_uid=conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
select_param=doc_file.filename,
|
||||
model_name=model_name,
|
||||
)
|
||||
chat: BaseChat = get_chat_instance(dialogue)
|
||||
resp = await chat.prepare()
|
||||
@ -264,7 +269,8 @@ def get_hist_messages(conv_uid: str):
|
||||
print(f"once:{once}")
|
||||
model_name = once.get("model_name", CFG.LLM_MODEL)
|
||||
once_message_vos = [
|
||||
message2Vo(element, once["chat_order"], model_name) for element in once["messages"]
|
||||
message2Vo(element, once["chat_order"], model_name)
|
||||
for element in once["messages"]
|
||||
]
|
||||
message_vos.extend(once_message_vos)
|
||||
return message_vos
|
||||
@ -293,9 +299,11 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
||||
"chat_session_id": dialogue.conv_uid,
|
||||
"current_user_input": dialogue.user_input,
|
||||
"select_param": dialogue.select_param,
|
||||
"model_name": dialogue.model_name
|
||||
"model_name": dialogue.model_name,
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(
|
||||
dialogue.chat_mode, **{"chat_param": chat_param}
|
||||
)
|
||||
return chat
|
||||
|
||||
|
||||
@ -314,7 +322,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
# dialogue.model_name = CFG.LLM_MODEL
|
||||
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
|
||||
print(
|
||||
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
||||
)
|
||||
chat: BaseChat = get_chat_instance(dialogue)
|
||||
# background_tasks = BackgroundTasks()
|
||||
# background_tasks.add_task(release_model_semaphore)
|
||||
@ -344,9 +354,10 @@ async def model_types():
|
||||
print(f"/controller/model/types")
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
|
||||
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
|
||||
)
|
||||
types = set()
|
||||
if response.status_code == 200:
|
||||
@ -387,5 +398,8 @@ async def stream_generator(chat):
|
||||
|
||||
def message2Vo(message: dict, order, model_name) -> MessageVo:
|
||||
return MessageVo(
|
||||
role=message["type"], context=message["data"]["content"], order=order, model_name=model_name
|
||||
role=message["type"],
|
||||
context=message["data"]["content"],
|
||||
order=order,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
@ -32,13 +32,13 @@ class BaseChat(ABC):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self, chat_param: Dict
|
||||
):
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.chat_session_id = chat_param["chat_session_id"]
|
||||
self.chat_mode = chat_param["chat_mode"]
|
||||
self.current_user_input: str = chat_param["current_user_input"]
|
||||
self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
|
||||
self.llm_model = (
|
||||
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
|
||||
)
|
||||
self.llm_echo = False
|
||||
|
||||
### load prompt template
|
||||
@ -58,7 +58,9 @@ class BaseChat(ABC):
|
||||
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
|
||||
|
||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
|
||||
self.current_message: OnceConversation = OnceConversation(
|
||||
self.chat_mode.value()
|
||||
)
|
||||
self.current_message.model_name = self.llm_model
|
||||
if chat_param["select_param"]:
|
||||
if len(self.chat_mode.param_types()) > 0:
|
||||
|
@ -21,16 +21,11 @@ class ChatDashboard(BaseChat):
|
||||
report_name: str
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_param: Dict
|
||||
):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatDashboard
|
||||
super().__init__(
|
||||
chat_param=chat_param
|
||||
)
|
||||
super().__init__(chat_param=chat_param)
|
||||
if not self.db_name:
|
||||
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
|
||||
self.db_name = self.db_name
|
||||
|
@ -37,9 +37,7 @@ class ChatExcel(BaseChat):
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
chat_param=chat_param
|
||||
)
|
||||
super().__init__(chat_param=chat_param)
|
||||
|
||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||
"""
|
||||
|
@ -43,9 +43,7 @@ class ExcelLearning(BaseChat):
|
||||
"select_param": select_param,
|
||||
"model_name": model_name,
|
||||
}
|
||||
super().__init__(
|
||||
chat_param=chat_param
|
||||
)
|
||||
super().__init__(chat_param=chat_param)
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
||||
|
@ -18,9 +18,7 @@ class ChatWithDbQA(BaseChat):
|
||||
""" """
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
|
||||
super().__init__(
|
||||
chat_param=chat_param
|
||||
)
|
||||
super().__init__(chat_param=chat_param)
|
||||
|
||||
if self.db_name:
|
||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||
|
@ -18,9 +18,7 @@ class ChatWithPlugin(BaseChat):
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.plugin_selector = chat_param.select_param
|
||||
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||
super().__init__(
|
||||
chat_param=chat_param
|
||||
)
|
||||
super().__init__(chat_param=chat_param)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
# 加载插件中可用命令
|
||||
|
Loading…
Reference in New Issue
Block a user