style:fmt

This commit is contained in:
aries_ckt
2023-08-29 20:28:16 +08:00
parent 71b9cd14a6
commit e70f05aea1
8 changed files with 42 additions and 42 deletions

View File

@@ -170,7 +170,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new( async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
): ):
conv_vo = __new_conversation(chat_mode, user_id) conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo) return Result.succ(conv_vo)
@@ -201,7 +201,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)): if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)) os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
with NamedTemporaryFile( with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False
) as tmp: ) as tmp:
tmp.write(await doc_file.read()) tmp.write(await doc_file.read())
tmp_path = tmp.name tmp_path = tmp.name

View File

@@ -39,7 +39,7 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) @router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables( async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = "" db_name: str, page_index: int, page_size: int, search_str: str = ""
): ):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}") logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
@@ -285,7 +285,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
find_chart = list( find_chart = list(
filter( filter(
lambda x: x["chart_name"] lambda x: x["chart_name"]
== chart_edit_context.chart_title, == chart_edit_context.chart_title,
charts, charts,
) )
)[0] )[0]

View File

@@ -155,7 +155,7 @@ class BaseOutputParser(ABC):
if i < 0: if i < 0:
return None return None
count = 1 count = 1
for j, c in enumerate(s[i + 1:], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "]": if c == "]":
count -= 1 count -= 1
elif c == "[": elif c == "[":
@@ -163,13 +163,13 @@ class BaseOutputParser(ABC):
if count == 0: if count == 0:
break break
assert count == 0 assert count == 0
return s[i: j + 1] return s[i : j + 1]
else: else:
i = s.find("{") i = s.find("{")
if i < 0: if i < 0:
return None return None
count = 1 count = 1
for j, c in enumerate(s[i + 1:], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}": if c == "}":
count -= 1 count -= 1
elif c == "{": elif c == "{":
@@ -177,7 +177,7 @@ class BaseOutputParser(ABC):
if count == 0: if count == 0:
break break
assert count == 0 assert count == 0
return s[i: j + 1] return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
""" """
@@ -194,9 +194,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output: # if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```") # cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"): if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):] cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"): if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):] cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"): if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip() cleaned_output = cleaned_output.strip()

View File

@@ -60,7 +60,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
): ):
self.chat_session_id = chat_session_id self.chat_session_id = chat_session_id
self.chat_mode = chat_mode self.chat_mode = chat_mode
@@ -305,7 +305,7 @@ class BaseChat(ABC):
system_messages = [] system_messages = []
for system_conv in system_convs: for system_conv in system_convs:
system_text += ( system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep system_conv.type + ":" + system_conv.content + self.prompt_template.sep
) )
system_messages.append( system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content) ModelMessage(role=system_conv.type, content=system_conv.content)
@@ -317,7 +317,7 @@ class BaseChat(ABC):
user_messages = [] user_messages = []
if user_conv: if user_conv:
user_text = ( user_text = (
user_conv.type + ":" + user_conv.content + self.prompt_template.sep user_conv.type + ":" + user_conv.content + self.prompt_template.sep
) )
user_messages.append( user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content) ModelMessage(role=user_conv.type, content=user_conv.content)
@@ -339,10 +339,10 @@ class BaseChat(ABC):
message_type = round_message["type"] message_type = round_message["type"]
message_content = round_message["data"]["content"] message_content = round_message["data"]["content"]
example_text += ( example_text += (
message_type message_type
+ ":" + ":"
+ message_content + message_content
+ self.prompt_template.sep + self.prompt_template.sep
) )
example_messages.append( example_messages.append(
ModelMessage(role=message_type, content=message_content) ModelMessage(role=message_type, content=message_content)
@@ -363,10 +363,10 @@ class BaseChat(ABC):
message_type = first_message["type"] message_type = first_message["type"]
message_content = first_message["data"]["content"] message_content = first_message["data"]["content"]
history_text += ( history_text += (
message_type message_type
+ ":" + ":"
+ message_content + message_content
+ self.prompt_template.sep + self.prompt_template.sep
) )
history_messages.append( history_messages.append(
ModelMessage(role=message_type, content=message_content) ModelMessage(role=message_type, content=message_content)
@@ -382,10 +382,10 @@ class BaseChat(ABC):
message_type = round_message["type"] message_type = round_message["type"]
message_content = round_message["data"]["content"] message_content = round_message["data"]["content"]
history_text += ( history_text += (
message_type message_type
+ ":" + ":"
+ message_content + message_content
+ self.prompt_template.sep + self.prompt_template.sep
) )
history_messages.append( history_messages.append(
ModelMessage( ModelMessage(
@@ -405,10 +405,10 @@ class BaseChat(ABC):
message_type = message["type"] message_type = message["type"]
message_content = message["data"]["content"] message_content = message["data"]["content"]
history_text += ( history_text += (
message_type message_type
+ ":" + ":"
+ message_content + message_content
+ self.prompt_template.sep + self.prompt_template.sep
) )
history_messages.append( history_messages.append(
ModelMessage(role=message_type, content=message_content) ModelMessage(role=message_type, content=message_content)

View File

@@ -167,7 +167,7 @@ def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
def _parse_model_messages( def _parse_model_messages(
messages: List[ModelMessage], messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]: ) -> Tuple[str, List[str], List[List[str, str]]]:
""" " """ "
Parameters: Parameters:

View File

@@ -22,11 +22,11 @@ class ChatDashboard(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(
self, self,
chat_session_id, chat_session_id,
user_input, user_input,
select_param: str = "", select_param: str = "",
report_name: str = "report", report_name: str = "report",
): ):
""" """ """ """
self.db_name = select_param self.db_name = select_param

View File

@@ -24,12 +24,12 @@ class ExcelLearning(BaseChat):
chat_scene: str = ChatScene.ExcelLearning.value() chat_scene: str = ChatScene.ExcelLearning.value()
def __init__( def __init__(
self, self,
chat_session_id, chat_session_id,
user_input, user_input,
parent_mode: Any = None, parent_mode: Any = None,
select_param: str = None, select_param: str = None,
excel_reader: Any = None, excel_reader: Any = None,
): ):
chat_mode = ChatScene.ExcelLearning chat_mode = ChatScene.ExcelLearning
""" """ """ """

View File

@@ -47,7 +47,7 @@ class LearningExcelOutputParser(BaseOutputParser):
keys = item.keys() keys = item.keys()
for key in keys: for key in keys:
html_colunms = ( html_colunms = (
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
) )
html_plans = f"### **分析计划**\n" html_plans = f"### **分析计划**\n"