style:fmt

This commit is contained in:
aries_ckt 2023-09-13 14:20:04 +08:00
parent 98a94268f0
commit 1b49267976
7 changed files with 35 additions and 32 deletions

View File

@ -217,7 +217,9 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
@router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
async def params_load(
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
@ -237,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file:
)
## chat prepare
dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
)
chat: BaseChat = get_chat_instance(dialogue)
resp = await chat.prepare()
@ -264,7 +269,8 @@ def get_hist_messages(conv_uid: str):
print(f"once:{once}")
model_name = once.get("model_name", CFG.LLM_MODEL)
once_message_vos = [
message2Vo(element, once["chat_order"], model_name) for element in once["messages"]
message2Vo(element, once["chat_order"], model_name)
for element in once["messages"]
]
message_vos.extend(once_message_vos)
return message_vos
@ -293,9 +299,11 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
"chat_session_id": dialogue.conv_uid,
"current_user_input": dialogue.user_input,
"select_param": dialogue.select_param,
"model_name": dialogue.model_name
"model_name": dialogue.model_name,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param}
)
return chat
@ -314,7 +322,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
@ -344,9 +354,10 @@ async def model_types():
print(f"/controller/model/types")
try:
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true",
)
types = set()
if response.status_code == 200:
@ -387,5 +398,8 @@ async def stream_generator(chat):
def message2Vo(message: dict, order, model_name) -> MessageVo:
return MessageVo(
role=message["type"], context=message["data"]["content"], order=order, model_name=model_name
role=message["type"],
context=message["data"]["content"],
order=order,
model_name=model_name,
)

View File

@ -32,13 +32,13 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self, chat_param: Dict
):
def __init__(self, chat_param: Dict):
self.chat_session_id = chat_param["chat_session_id"]
self.chat_mode = chat_param["chat_mode"]
self.current_user_input: str = chat_param["current_user_input"]
self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
self.llm_model = (
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
)
self.llm_echo = False
### load prompt template
@ -58,7 +58,9 @@ class BaseChat(ABC):
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
self.current_message: OnceConversation = OnceConversation(
self.chat_mode.value()
)
self.current_message.model_name = self.llm_model
if chat_param["select_param"]:
if len(self.chat_mode.param_types()) > 0:

View File

@ -21,16 +21,11 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""
def __init__(
self,
chat_param: Dict
):
def __init__(self, chat_param: Dict):
""" """
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatDashboard
super().__init__(
chat_param=chat_param
)
super().__init__(chat_param=chat_param)
if not self.db_name:
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
self.db_name = self.db_name

View File

@ -37,9 +37,7 @@ class ChatExcel(BaseChat):
)
)
super().__init__(
chat_param=chat_param
)
super().__init__(chat_param=chat_param)
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""

View File

@ -43,9 +43,7 @@ class ExcelLearning(BaseChat):
"select_param": select_param,
"model_name": model_name,
}
super().__init__(
chat_param=chat_param
)
super().__init__(chat_param=chat_param)
if parent_mode:
self.current_message.chat_mode = parent_mode.value()

View File

@ -18,9 +18,7 @@ class ChatWithDbQA(BaseChat):
""" """
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
super().__init__(
chat_param=chat_param
)
super().__init__(chat_param=chat_param)
if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)

View File

@ -18,9 +18,7 @@ class ChatWithPlugin(BaseChat):
def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param.select_param
chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__(
chat_param=chat_param
)
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令