style:fmt

This commit is contained in:
aries_ckt 2023-09-13 14:20:04 +08:00
parent 98a94268f0
commit 1b49267976
7 changed files with 35 additions and 32 deletions

View File

@ -217,7 +217,9 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
@router.post("/v1/chat/mode/params/file/load") @router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)): async def params_load(
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}") print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try: try:
if doc_file: if doc_file:
@ -237,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file:
) )
## chat prepare ## chat prepare
dialogue = ConversationVo( dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
) )
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = get_chat_instance(dialogue)
resp = await chat.prepare() resp = await chat.prepare()
@ -264,7 +269,8 @@ def get_hist_messages(conv_uid: str):
print(f"once:{once}") print(f"once:{once}")
model_name = once.get("model_name", CFG.LLM_MODEL) model_name = once.get("model_name", CFG.LLM_MODEL)
once_message_vos = [ once_message_vos = [
message2Vo(element, once["chat_order"], model_name) for element in once["messages"] message2Vo(element, once["chat_order"], model_name)
for element in once["messages"]
] ]
message_vos.extend(once_message_vos) message_vos.extend(once_message_vos)
return message_vos return message_vos
@ -293,9 +299,11 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
"chat_session_id": dialogue.conv_uid, "chat_session_id": dialogue.conv_uid,
"current_user_input": dialogue.user_input, "current_user_input": dialogue.user_input,
"select_param": dialogue.select_param, "select_param": dialogue.select_param,
"model_name": dialogue.model_name "model_name": dialogue.model_name,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param}) chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param}
)
return chat return chat
@ -314,7 +322,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()): async def chat_completions(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL # dialogue.model_name = CFG.LLM_MODEL
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}") print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks() # background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore) # background_tasks.add_task(release_model_semaphore)
@ -344,6 +354,7 @@ async def model_types():
print(f"/controller/model/types") print(f"/controller/model/types")
try: try:
import httpx import httpx
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true", f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
@ -387,5 +398,8 @@ async def stream_generator(chat):
def message2Vo(message: dict, order, model_name) -> MessageVo: def message2Vo(message: dict, order, model_name) -> MessageVo:
return MessageVo( return MessageVo(
role=message["type"], context=message["data"]["content"], order=order, model_name=model_name role=message["type"],
context=message["data"]["content"],
order=order,
model_name=model_name,
) )

View File

@ -32,13 +32,13 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(self, chat_param: Dict):
self, chat_param: Dict
):
self.chat_session_id = chat_param["chat_session_id"] self.chat_session_id = chat_param["chat_session_id"]
self.chat_mode = chat_param["chat_mode"] self.chat_mode = chat_param["chat_mode"]
self.current_user_input: str = chat_param["current_user_input"] self.current_user_input: str = chat_param["current_user_input"]
self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL self.llm_model = (
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
)
self.llm_echo = False self.llm_echo = False
### load prompt template ### load prompt template
@ -58,7 +58,9 @@ class BaseChat(ABC):
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"]) self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages() self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value()) self.current_message: OnceConversation = OnceConversation(
self.chat_mode.value()
)
self.current_message.model_name = self.llm_model self.current_message.model_name = self.llm_model
if chat_param["select_param"]: if chat_param["select_param"]:
if len(self.chat_mode.param_types()) > 0: if len(self.chat_mode.param_types()) > 0:

View File

@ -21,16 +21,11 @@ class ChatDashboard(BaseChat):
report_name: str report_name: str
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_param: Dict):
self,
chat_param: Dict
):
""" """ """ """
self.db_name = chat_param["select_param"] self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatDashboard chat_param["chat_mode"] = ChatScene.ChatDashboard
super().__init__( super().__init__(chat_param=chat_param)
chat_param=chat_param
)
if not self.db_name: if not self.db_name:
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!") raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
self.db_name = self.db_name self.db_name = self.db_name

View File

@ -37,9 +37,7 @@ class ChatExcel(BaseChat):
) )
) )
super().__init__( super().__init__(chat_param=chat_param)
chat_param=chat_param
)
def _generate_command_string(self, command: Dict[str, Any]) -> str: def _generate_command_string(self, command: Dict[str, Any]) -> str:
""" """

View File

@ -43,9 +43,7 @@ class ExcelLearning(BaseChat):
"select_param": select_param, "select_param": select_param,
"model_name": model_name, "model_name": model_name,
} }
super().__init__( super().__init__(chat_param=chat_param)
chat_param=chat_param
)
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()

View File

@ -18,9 +18,7 @@ class ChatWithDbQA(BaseChat):
""" """ """ """
self.db_name = chat_param["select_param"] self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatWithDbQA chat_param["chat_mode"] = ChatScene.ChatWithDbQA
super().__init__( super().__init__(chat_param=chat_param)
chat_param=chat_param
)
if self.db_name: if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)

View File

@ -18,9 +18,7 @@ class ChatWithPlugin(BaseChat):
def __init__(self, chat_param: Dict): def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param.select_param self.plugin_selector = chat_param.select_param
chat_param["chat_mode"] = ChatScene.ChatExecution chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__( super().__init__(chat_param=chat_param)
chat_param=chat_param
)
self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令 # 加载插件中可用命令