From 8c93d355f583038977ad5c0a702af923570262a2 Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Thu, 1 Jun 2023 18:17:37 +0800 Subject: [PATCH] feature:db_summary --- pilot/agent/json_fix_llm.py | 2 - pilot/common/custom_data_structure.py | 4 +- pilot/common/sql_database.py | 59 ++++++++------- pilot/configs/config.py | 8 ++- pilot/connections/base.py | 4 +- pilot/connections/rdbms/mysql.py | 3 +- pilot/conversation.py | 3 +- pilot/language/lang_content_mapping.py | 1 - pilot/model/llm_out/chatglm_llm.py | 4 +- pilot/model/llm_out/guanaco_llm.py | 20 +++--- pilot/out_parser/base.py | 23 +++--- pilot/prompts/generator.py | 6 +- pilot/prompts/prompt_new.py | 10 +-- pilot/scene/base_chat.py | 64 ++++++++++------- pilot/scene/chat_db/auto_execute/chat.py | 24 ++++--- .../scene/chat_db/auto_execute/out_parser.py | 3 +- pilot/scene/chat_db/auto_execute/prompt.py | 1 - pilot/scene/chat_db/professional_qa/chat.py | 27 +++---- .../chat_db/professional_qa/out_parser.py | 2 +- pilot/scene/chat_db/professional_qa/prompt.py | 4 +- pilot/scene/chat_execution/chat.py | 52 +++++++++----- pilot/scene/chat_execution/out_parser.py | 5 +- pilot/scene/chat_execution/prompt.py | 4 +- pilot/scene/chat_factory.py | 1 + pilot/scene/chat_knowledge/custom/chat.py | 31 ++++---- .../scene/chat_knowledge/custom/out_parser.py | 2 +- pilot/scene/chat_knowledge/custom/prompt.py | 3 - pilot/scene/chat_knowledge/default/chat.py | 38 +++++----- .../chat_knowledge/default/out_parser.py | 2 +- pilot/scene/chat_knowledge/default/prompt.py | 3 - .../chat_knowledge/inner_db_summary/chat.py | 35 +++++---- .../inner_db_summary/out_parser.py | 10 +-- .../chat_knowledge/inner_db_summary/prompt.py | 19 ++--- pilot/scene/chat_knowledge/url/chat.py | 28 ++++---- pilot/scene/chat_knowledge/url/out_parser.py | 2 +- pilot/scene/chat_knowledge/url/prompt.py | 3 - pilot/scene/chat_normal/chat.py | 21 +++--- pilot/scene/chat_normal/out_parser.py | 2 +- pilot/scene/chat_normal/prompt.py | 2 - pilot/server/webserver.py | 72 ++++++++++++------- pilot/source_embedding/knowledge_embedding.py | 11 ++- pilot/summary/db_summary_client.py | 43 ++++------- pilot/summary/mysql_db_summary.py | 21 +++++- pilot/vector_store/connector.py | 1 + 44 files changed, 369 insertions(+), 314 deletions(-) diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py index 075634784..800c2858a 100644 --- a/pilot/agent/json_fix_llm.py +++ b/pilot/agent/json_fix_llm.py @@ -55,8 +55,6 @@ def fix_and_parse_json( logger.error("Parameter parsing error", e) - - def correct_json(json_to_load: str) -> str: """ Correct common JSON errors.
diff --git a/pilot/common/custom_data_structure.py b/pilot/common/custom_data_structure.py index ca0892528..a8502143a 100644 --- a/pilot/common/custom_data_structure.py +++ b/pilot/common/custom_data_structure.py @@ -1,6 +1,7 @@ from collections import OrderedDict from collections import deque + class FixedSizeDict(OrderedDict): def __init__(self, max_size): super().__init__() @@ -11,6 +12,7 @@ class FixedSizeDict(OrderedDict): self.popitem(last=False) super().__setitem__(key, value) + class FixedSizeList: def __init__(self, max_size): self.max_size = max_size @@ -29,4 +31,4 @@ class FixedSizeList: return len(self.list) def __str__(self): - return str(list(self.list)) \ No newline at end of file + return str(list(self.list)) diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index cc09e4328..bc3aa8340 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -1,6 +1,6 @@ from __future__ import annotations import sqlparse -import regex as re +import regex as re import warnings from typing import Any, Iterable, List, Optional from pydantic import BaseModel, Field, root_validator, validator, Extra @@ -283,16 +283,16 @@ class Database: return f"Error: {e}" def __write(self, session, write_sql): - print(f"Write[{write_sql}]") - db_cache = self.get_session_db(session) - result = session.execute(text(write_sql)) - session.commit() - #TODO Subsequent optimization of dynamically specified database submission loss target problem - session.execute(text(f"use `{db_cache}`")) - print(f"SQL[{write_sql}], result:{result.rowcount}") - return result.rowcount + print(f"Write[{write_sql}]") + db_cache = self.get_session_db(session) + result = session.execute(text(write_sql)) + session.commit() + # TODO Subsequent optimization of dynamically specified database submission loss target problem + session.execute(text(f"use `{db_cache}`")) + print(f"SQL[{write_sql}], result:{result.rowcount}") + return result.rowcount - def __query(self,session, query, fetch: str = "all"): + def __query(self, session, query, fetch: str = "all"): """ only for query Args: @@ -390,37 +390,44 @@ class Database: cmd_type = parts[0] # Process according to the command type - if cmd_type == 'insert': - match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()) + if cmd_type == "insert": + match = re.match( + r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower() + ) if match: table_name, columns, values = match.groups() # Split the column list and value list into individual columns and values - columns = columns.split(',') - values = values.split(',') + columns = columns.split(",") + values = values.split(",") # Build the WHERE clause - where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)]) - return f'SELECT * FROM {table_name} WHERE {where_clause}' + where_clause = " AND ".join( + [ + f"{col.strip()}={val.strip()}" + for col, val in zip(columns, values) + ] + ) + return f"SELECT * FROM {table_name} WHERE {where_clause}" - elif cmd_type == 'delete': + elif cmd_type == "delete": table_name = parts[2] # delete from ...
# Return a SELECT statement that selects all data from the table - return f'SELECT * FROM {table_name}' + return f"SELECT * FROM {table_name}" - elif cmd_type == 'update': + elif cmd_type == "update": table_name = parts[1] - set_idx = parts.index('set') - where_idx = parts.index('where') + set_idx = parts.index("set") + where_idx = parts.index("where") # Extract the column name from the `set` clause - set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip() + set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip() # Extract the condition clause after `where` - where_clause = ' '.join(parts[where_idx + 1:]) + where_clause = " ".join(parts[where_idx + 1 :]) # Return a SELECT statement that selects the updated data - return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}' + return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}" else: raise ValueError(f"Unsupported SQL command type: {cmd_type}") def __sql_parse(self, sql): - sql = sql.strip() + sql = sql.strip() parsed = sqlparse.parse(sql)[0] sql_type = parsed.get_type() @@ -429,8 +436,6 @@ class Database: print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}") return parsed, ttype, sql_type - - def get_indexes(self, table_name): """Get table indexes about specified table.""" session = self._db_sessions() diff --git a/pilot/configs/config.py b/pilot/configs/config.py index c39f272f5..3762b43c1 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -103,8 +103,12 @@ class Config(metaclass=Singleton): else: self.plugins_denylist = [] ### Native SQL Execution Capability Control Configuration - self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True" - self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True" + self.NATIVE_SQL_CAN_RUN_DDL = ( + os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True" + ) + self.NATIVE_SQL_CAN_RUN_WRITE = ( + os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True" + ) ### Local database connection configuration self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") diff --git a/pilot/connections/base.py b/pilot/connections/base.py index 3905f410d..40f6430f7 100644 --- a/pilot/connections/base.py +++ b/pilot/connections/base.py @@ -11,11 +11,9 @@ class BaseConnect(BaseModel, ABC): type driver: str - def get_session(self, db_name: str): pass - def get_table_names(self) -> Iterable[str]: pass @@ -32,4 +30,4 @@ class BaseConnect(BaseModel, ABC): pass def run(self, session, command: str, fetch: str = "all") -> List: - pass \ No newline at end of file + pass diff --git a/pilot/connections/rdbms/mysql.py b/pilot/connections/rdbms/mysql.py index 9d99f3e9b..e50e9c679 100644 --- a/pilot/connections/rdbms/mysql.py +++ b/pilot/connections/rdbms/mysql.py @@ -11,8 +11,7 @@ class MySQLConnect(RDBMSDatabase): Usage: """ - type:str = "MySQL" + type: str = "MySQL" connect_url = "mysql+pymysql://" default_db = ["information_schema", "performance_schema", "sys", "mysql"] - diff --git a/pilot/conversation.py b/pilot/conversation.py index f03c13e31..40759ffc8 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -106,7 +106,7 @@ class Conversation: conv_default = Conversation( - system = None, + system=None, roles=("human", "ai"), messages=[], offset=0, @@ -298,7 +298,6 @@ chat_mode_title = { "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"), "chat_use_plugin": get_lang_text("chat_use_plugin"), "knowledge_qa": get_lang_text("knowledge_qa"), - } conversation_sql_mode = { diff --git a/pilot/language/lang_content_mapping.py b/pilot/language/lang_content_mapping.py index bcea7ed3c..86aa3fa3c 100644 ---
a/pilot/language/lang_content_mapping.py +++ b/pilot/language/lang_content_mapping.py @@ -29,7 +29,6 @@ lang_dicts = { "url_input_label": "输入网页地址", "add_as_new_klg": "添加为新知识库", "add_file_to_klg": "向知识库中添加文件", - "upload_file": "上传文件", "add_file": "添加文件", "upload_and_load_to_klg": "上传并加载到知识库", diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py index dcc4a88bc..1a44678fe 100644 --- a/pilot/model/llm_out/chatglm_llm.py +++ b/pilot/model/llm_out/chatglm_llm.py @@ -9,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER @torch.inference_mode() def chatglm_generate_stream( - model, tokenizer, params, device, context_len=2048, stream_interval=2 + model, tokenizer, params, device, context_len=2048, stream_interval=2 ): """Generate text using chatglm model's chat api""" prompt = params["prompt"] @@ -57,7 +57,7 @@ def chatglm_generate_stream( # i = 0 for i, (response, new_hist) in enumerate( - model.stream_chat(tokenizer, query, hist, **generate_kwargs) + model.stream_chat(tokenizer, query, hist, **generate_kwargs) ): if echo: output = query + " " + response diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py index 6c209b565..bf74ca766 100644 --- a/pilot/model/llm_out/guanaco_llm.py +++ b/pilot/model/llm_out/guanaco_llm.py @@ -4,6 +4,7 @@ from threading import Thread from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria from pilot.conversation import ROLE_ASSISTANT, ROLE_USER + def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py""" @@ -43,15 +44,20 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): input_ids = tokenizer(query, return_tensors="pt").input_ids input_ids = input_ids.to(model.device) - streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + streamer = TextIteratorStreamer( + tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True + ) stop_token_ids = [0] + class StopOnTokens(StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False - + stop = StopOnTokens() generate_kwargs = dict( @@ -59,17 +65,16 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): max_new_tokens=512, temperature=1.0, do_sample=True, - top_k=1, + top_k=1, streamer=streamer, repetition_penalty=1.7, - stopping_criteria=StoppingCriteriaList([stop]) + stopping_criteria=StoppingCriteriaList([stop]), ) - t1 = Thread(target=model.generate, kwargs=generate_kwargs) t1.start() - generator = model.generate(**generate_kwargs) + generator = model.generate(**generate_kwargs) for output in generator: # new_tokens = len(output) - len(input_ids[0]) decoded_output = tokenizer.decode(output) @@ -79,4 +84,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): out = decoded_output.split("### Response:")[-1].strip() yield out - diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 663b87a7d..34b2bc8c7 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -53,7 +53,6 @@ class BaseOutputParser(ABC): """ if data["error_code"] == 0: if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]: - output = 
data["text"][skip_echo_len:].strip() else: output = data["text"].strip() @@ -65,8 +64,7 @@ class BaseOutputParser(ABC): return output # TODO 后续和模型绑定 - def parse_model_stream_resp(self, response, skip_echo_len): - + def parse_model_stream_resp(self, response, skip_echo_len): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) @@ -74,7 +72,7 @@ class BaseOutputParser(ABC): """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL: + if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL: output = data["text"][skip_echo_len:].strip() else: output = data["text"].strip() @@ -82,9 +80,7 @@ class BaseOutputParser(ABC): output = self.__post_process_code(output) yield output else: - output = ( - data["text"] + f" (error_code: {data['error_code']})" - ) + output = data["text"] + f" (error_code: {data['error_code']})" yield output def parse_model_nostream_resp(self, response, sep: str): @@ -114,7 +110,6 @@ class BaseOutputParser(ABC): else: raise ValueError("Model server error!code=" + respObj_ex["error_code"]) - def parse_prompt_response(self, model_out_text) -> T: """ parse model out text to prompt define response @@ -130,9 +125,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json"):] + cleaned_output = cleaned_output[len("```json") :] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```"):] + cleaned_output = cleaned_output[len("```") :] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -144,7 +139,13 @@ class BaseOutputParser(ABC): cleaned_output = m.group(0) else: raise ValueError("model server out not fllow the prompt!") - cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '') + cleaned_output = ( + cleaned_output.strip() + .replace("\n", "") + .replace("\\n", "") + .replace("\\", "") + .replace("\\", "") + ) return cleaned_output def parse_view_response(self, ai_text, data) -> str: diff --git a/pilot/prompts/generator.py b/pilot/prompts/generator.py index 22f998a67..6752cd1e1 100644 --- a/pilot/prompts/generator.py +++ b/pilot/prompts/generator.py @@ -133,10 +133,8 @@ class PluginPromptGenerator: else: return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) - - def generate_commands_string(self)->str: - return f"{self._generate_numbered_list(self.commands, item_type='command')}" - + def generate_commands_string(self) -> str: + return f"{self._generate_numbered_list(self.commands, item_type='command')}" def generate_prompt_string(self) -> str: """ diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 888b6f81e..6f50895fa 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -33,13 +33,13 @@ class PromptTemplate(BaseModel, ABC): """A list of the names of the variables the prompt template expects.""" template_scene: Optional[str] - template_define: Optional[str] + template_define: Optional[str] """this template define""" - template: Optional[str] + template: Optional[str] """The prompt template.""" template_format: str = "f-string" """The format of the prompt template. 
Options are: 'f-string', 'jinja2'.""" - response_format: Optional[str] + response_format: Optional[str] """default use stream out""" stream_out: bool = True """""" @@ -62,7 +62,9 @@ class PromptTemplate(BaseModel, ABC): if self.template: if self.response_format: kwargs["response"] = json.dumps(self.response_format, indent=4) - return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + return DEFAULT_FORMATTER_MAPPING[self.template_format]( + self.template, **kwargs + ) def add_goals(self, goal: str) -> None: self.goals.append(goal) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 8c8dba501..c1d831d0d 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -57,7 +57,14 @@ class BaseChat(ABC): arbitrary_types_allowed = True - def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input): + def __init__( + self, + temperature, + max_new_tokens, + chat_mode, + chat_session_id, + current_user_input, + ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode self.current_user_input: str = current_user_input @@ -68,7 +75,9 @@ class BaseChat(ABC): ## TEST self.memory = FileHistoryMemory(chat_session_id) ### load prompt template - self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value] + self.prompt_template: PromptTemplate = CFG.prompt_templates[ + self.chat_mode.value + ] self.history_message: List[OnceConversation] = [] self.current_message: OnceConversation = OnceConversation() self.current_tokens_used: int = 0 @@ -129,7 +138,7 @@ class BaseChat(ABC): def stream_call(self): payload = self.__call_base() - self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11 + self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 logger.info(f"Request: \n{payload}") ai_response_text = "" try: @@ -141,7 +150,7 @@ class BaseChat(ABC): stream=True, timeout=120, ) - return response; + return response # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len) @@ -175,29 +184,37 @@ class BaseChat(ABC): ### output parse ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep) + self.prompt_template.output_parser.parse_model_nostream_resp( + response, self.prompt_template.sep + ) ) self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response( + ai_response_text + ) + ) result = self.do_with_prompt_response(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): - if isinstance(prompt_define_response.thoughts, dict): + if isinstance(prompt_define_response.thoughts, dict): if "speak" in prompt_define_response.thoughts: speak_to_user = prompt_define_response.thoughts.get("speak") else: speak_to_user = str(prompt_define_response.thoughts) else: - if hasattr(prompt_define_response.thoughts, "speak"): + if hasattr(prompt_define_response.thoughts, "speak"): speak_to_user = prompt_define_response.thoughts.get("speak") - elif hasattr(prompt_define_response.thoughts, "reasoning"): + elif hasattr(prompt_define_response.thoughts, "reasoning"): speak_to_user = prompt_define_response.thoughts.get("reasoning") else: speak_to_user = prompt_define_response.thoughts else: speak_to_user = prompt_define_response - view_message =
self.prompt_template.output_parser.parse_view_response(speak_to_user, result) + view_message = self.prompt_template.output_parser.parse_view_response( + speak_to_user, result + ) self.current_message.add_view_message(view_message) except Exception as e: print(traceback.format_exc()) @@ -226,20 +243,20 @@ class BaseChat(ABC): for first_message in self.history_message[0].messages: if not isinstance(first_message, ViewMessage): text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep + first_message.type + + ":" + + first_message.content + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 for last_message in self.history_message[-index:].messages: if not isinstance(last_message, ViewMessage): text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep + last_message.type + + ":" + + last_message.content + + self.prompt_template.sep ) else: @@ -248,16 +265,16 @@ class BaseChat(ABC): for message in conversation.messages: if not isinstance(message, ViewMessage): text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep + message.type + + ":" + + message.content + + self.prompt_template.sep ) ### current conversation for now_message in self.current_message.messages: text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep + now_message.type + ":" + now_message.content + self.prompt_template.sep ) return text @@ -288,4 +305,3 @@ class BaseChat(ABC): """ pass - diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 2882fb1cc..1f4597789 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -21,15 +21,21 @@ class ChatWithDbAutoExecute(BaseChat): """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input): + def __init__( + self, temperature, max_new_tokens, chat_session_id, db_name, user_input + ): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatWithDbExecute, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatWithDbExecute, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) if not db_name: - raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should choose a db!") + raise ValueError( + f"{ChatScene.ChatWithDbExecute.value} mode should choose a db!" + ) self.db_name = db_name self.database = CFG.local_db # Prepare DB info (get the connection to the specified database) @@ -40,9 +46,7 @@ class ChatWithDbAutoExecute(BaseChat): try: from pilot.summary.db_summary_client import DBSummaryClient except ImportError: - raise ValueError( - "Could not import DBSummaryClient. " - ) + raise ValueError("Could not import DBSummaryClient. 
") input_values = { "input": self.current_user_input, "top_k": str(self.top_k), diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index 66b3520fd..94cc6ea9e 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -20,9 +20,8 @@ class DbChatOutputParser(BaseOutputParser): def __init__(self, sep: str, is_stream_out: bool): super().__init__(sep=sep, is_stream_out=is_stream_out) - def parse_prompt_response(self, model_out_text): - clean_str = super().parse_prompt_response(model_out_text); + clean_str = super().parse_prompt_response(model_out_text) print("clean prompt response:", clean_str) response = json.loads(clean_str) sql, thoughts = response["sql"], response["thoughts"] diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 796d61f97..938860aaf 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -62,4 +62,3 @@ prompt = PromptTemplate( ), ) CFG.prompt_templates.update({prompt.template_scene: prompt}) - diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 74f83ddaa..66b751533 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -19,13 +19,17 @@ class ChatWithDbQA(BaseChat): """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input): + def __init__( + self, temperature, max_new_tokens, chat_session_id, db_name, user_input + ): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatWithDbQA, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatWithDbQA, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) self.db_name = db_name if db_name: self.database = CFG.local_db @@ -34,17 +38,16 @@ class ChatWithDbQA(BaseChat): self.top_k: int = 5 def generate_input_values(self): - table_info = "" dialect = "mysql" try: from pilot.summary.db_summary_client import DBSummaryClient except ImportError: - raise ValueError( - "Could not import DBSummaryClient. " - ) + raise ValueError("Could not import DBSummaryClient. 
") if self.db_name: - table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + table_info = DBSummaryClient.get_similar_tables( + dbname=self.db_name, query=self.current_user_input, topk=self.top_k + ) # table_info = self.database.table_simple_info(self.db_connect) dialect = self.database.dialect @@ -52,7 +55,7 @@ class ChatWithDbQA(BaseChat): "input": self.current_user_input, "top_k": str(self.top_k), "dialect": dialect, - "table_info": table_info + "table_info": table_info, } return input_values diff --git a/pilot/scene/chat_db/professional_qa/out_parser.py b/pilot/scene/chat_db/professional_qa/out_parser.py index 0b8277d63..e5edc9b20 100644 --- a/pilot/scene/chat_db/professional_qa/out_parser.py +++ b/pilot/scene/chat_db/professional_qa/out_parser.py @@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): +class NormalChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: return model_out_text diff --git a/pilot/scene/chat_db/professional_qa/prompt.py b/pilot/scene/chat_db/professional_qa/prompt.py index 00fc87c03..9cc35b2e4 100644 --- a/pilot/scene/chat_db/professional_qa/prompt.py +++ b/pilot/scene/chat_db/professional_qa/prompt.py @@ -27,7 +27,6 @@ Pay attention to use only the column names that you can see in the schema descri """ - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = True @@ -37,7 +36,7 @@ prompt = PromptTemplate( input_variables=["input", "table_info", "dialect", "top_k"], response_format=None, template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX , + template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX, stream_out=PROMPT_NEED_NEED_STREAM_OUT, output_parser=NormalChatOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT @@ -45,4 +44,3 @@ prompt = PromptTemplate( ) CFG.prompt_templates.update({prompt.template_scene: prompt}) - diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 464df9ba0..e25a17340 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -14,56 +14,72 @@ from pilot.scene.chat_execution.prompt import prompt CFG = Config() + class ChatWithPlugin(BaseChat): chat_scene: str = ChatScene.ChatExecution.value - plugins_prompt_generator:PluginPromptGenerator + plugins_prompt_generator: PluginPromptGenerator select_plugin: str = None - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None): - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatExecution, - chat_session_id=chat_session_id, - current_user_input=user_input) + def __init__( + self, + temperature, + max_new_tokens, + chat_session_id, + user_input, + plugin_selector: str = None, + ): + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatExecution, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator.command_registry = CFG.command_registry # 加载插件中可用命令 self.select_plugin = plugin_selector if self.select_plugin: for plugin in CFG.plugins: - if plugin._name == plugin_selector : + if plugin._name == plugin_selector: if not plugin.can_handle_post_prompt(): continue - self.plugins_prompt_generator = 
plugin.post_prompt(self.plugins_prompt_generator) + self.plugins_prompt_generator = plugin.post_prompt( + self.plugins_prompt_generator + ) else: for plugin in CFG.plugins: if not plugin.can_handle_post_prompt(): continue - self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator) - - - + self.plugins_prompt_generator = plugin.post_prompt( + self.plugins_prompt_generator + ) def generate_input_values(self): input_values = { "input": self.current_user_input, - "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)), - "commands_infos": self.plugins_prompt_generator.generate_commands_string() + "constraints": self.__list_to_prompt_str( + list(self.plugins_prompt_generator.constraints) + ), + "commands_infos": self.plugins_prompt_generator.generate_commands_string(), } return input_values def do_with_prompt_response(self, prompt_response): ## plugin command run - return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator) + return execute_command( + str(prompt_response.command.get("name")), + prompt_response.command.get("args", {}), + self.plugins_prompt_generator, + ) def chat_show(self): super().chat_show() - def __list_to_prompt_str(self, list: List) -> str: if list: - separator = '\n' + separator = "\n" return separator.join(list) else: return "" diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index f9796ef3d..4d8989292 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -10,14 +10,13 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + class PluginAction(NamedTuple): command: Dict thoughts: Dict - class PluginChatOutputParser(BaseOutputParser): - def parse_prompt_response(self, model_out_text) -> T: response = json.loads(super().parse_prompt_response(model_out_text)) command, thoughts = response["command"], response["thoughts"] @@ -25,7 +24,7 @@ class PluginChatOutputParser(BaseOutputParser): def parse_view_response(self, speak, data) -> str: ### tool out data to table view - print(f"parse_view_response:{speak},{str(data)}" ) + print(f"parse_view_response:{speak},{str(data)}") view_text = f"##### {speak}" + "\n" + str(data) return view_text diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index 6875689cf..2091a6469 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -53,7 +53,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False prompt = PromptTemplate( template_scene=ChatScene.ChatExecution.value, - input_variables=["input", "constraints", "commands_infos", "response"], + input_variables=["input", "constraints", "commands_infos", "response"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), template_define=PROMPT_SCENE_DEFINE, template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, @@ -63,4 +63,4 @@ prompt = PromptTemplate( ), ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) \ No newline at end of file +CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 2e67df66c..63d00b36c 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -11,6 +11,7 @@ from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge from 
pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary + class ChatFactory(metaclass=Singleton): @staticmethod def get_implementation(chat_mode, **kwargs): diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index f0db3df8f..a094b9d6f 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -1,4 +1,3 @@ - from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -24,18 +23,22 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding CFG = Config() -class ChatNewKnowledge (BaseChat): +class ChatNewKnowledge(BaseChat): chat_scene: str = ChatScene.ChatNewKnowledge.value """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name): + def __init__( + self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name + ): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatNewKnowledge, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNewKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) self.knowledge_name = knowledge_name vector_store_config = { "vector_store_name": knowledge_name, @@ -49,21 +52,17 @@ class ChatNewKnowledge (BaseChat): vector_store_config=vector_store_config, ) - def generate_input_values(self): - docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = self.knowledge_embedding_client.similar_search( + self.current_user_input, VECTOR_SEARCH_TOP_K + ) docs = docs[:2000] - input_values = { - "context": docs, - "question": self.current_user_input - } + input_values = {"context": docs, "question": self.current_user_input} return input_values def do_with_prompt_response(self, prompt_response): return prompt_response - - @property def chat_type(self) -> str: return ChatScene.ChatNewKnowledge.value diff --git a/pilot/scene/chat_knowledge/custom/out_parser.py b/pilot/scene/chat_knowledge/custom/out_parser.py index 0b8277d63..e5edc9b20 100644 --- a/pilot/scene/chat_knowledge/custom/out_parser.py +++ b/pilot/scene/chat_knowledge/custom/out_parser.py @@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): +class NormalChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: return model_out_text diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py index ab96c1703..c3153c819 100644 --- a/pilot/scene/chat_knowledge/custom/prompt.py +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -23,7 +23,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用 """ - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = True @@ -42,5 +41,3 @@ prompt = PromptTemplate( CFG.prompt_templates.update({prompt.template_scene: prompt}) - - diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 978570d91..5d9c3ccf4 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -1,4 +1,3 @@ - from pilot.scene.base_chat 
import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -24,43 +23,42 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding CFG = Config() -class ChatDefaultKnowledge (BaseChat): +class ChatDefaultKnowledge(BaseChat): chat_scene: str = ChatScene.ChatKnowledge.value """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatKnowledge, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) vector_store_config = { "vector_store_name": "default", "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) def generate_input_values(self): - docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = self.knowledge_embedding_client.similar_search( + self.current_user_input, VECTOR_SEARCH_TOP_K + ) docs = docs[:2000] - input_values = { - "context": docs, - "question": self.current_user_input - } + input_values = {"context": docs, "question": self.current_user_input} return input_values def do_with_prompt_response(self, prompt_response): return prompt_response - - @property def chat_type(self) -> str: return ChatScene.ChatKnowledge.value diff --git a/pilot/scene/chat_knowledge/default/out_parser.py b/pilot/scene/chat_knowledge/default/out_parser.py index 0b8277d63..e5edc9b20 100644 --- a/pilot/scene/chat_knowledge/default/out_parser.py +++ b/pilot/scene/chat_knowledge/default/out_parser.py @@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): +class NormalChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: return model_out_text diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index d2ba473ab..51d2419d5 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用 """ - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = True @@ -39,5 +38,3 @@ prompt = PromptTemplate( CFG.prompt_templates.update({prompt.template_scene: prompt}) - - diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index cbdc44538..e149f4a1b 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -1,4 +1,3 @@ - from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene from pilot.configs.config import Config @@ -8,34 +7,42 @@ from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt CFG = Config() -class InnerChatDBSummary 
(BaseChat): chat_scene: str = ChatScene.InnerChatDBSummary.value """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary): + def __init__( + self, + temperature, + max_new_tokens, + chat_session_id, + user_input, + db_select, + db_summary, + ): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.InnerChatDBSummary, - chat_session_id=chat_session_id, - current_user_input=user_input) - self.db_name = db_select - self.db_summary = db_summary + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.InnerChatDBSummary, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) + self.db_input = db_select + self.db_summary = db_summary def generate_input_values(self): input_values = { - "db_input": self.db_name, - "db_profile_summary": self.db_summary + "db_input": self.db_input, + "db_profile_summary": self.db_summary, } return input_values def do_with_prompt_response(self, prompt_response): return prompt_response - - @property def chat_type(self) -> str: return ChatScene.InnerChatDBSummary.value diff --git a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py index b17571edd..5731857e8 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py @@ -10,13 +10,15 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): - def parse_prompt_response(self, model_out_text) -> T: - return model_out_text +class NormalChatOutputParser(BaseOutputParser): + def parse_prompt_response(self, model_out_text): + clean_str = super().parse_prompt_response(model_out_text) + print("clean prompt response:", clean_str) + return clean_str def parse_view_response(self, ai_text, data) -> str: - return ai_text["table"] + return ai_text def get_format_instructions(self) -> str: pass diff --git a/pilot/scene/chat_knowledge/inner_db_summary/prompt.py b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py index 739bf0364..d2c6b1913 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/prompt.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py @@ -7,33 +7,30 @@ from pilot.configs.config import Config from pilot.scene.base import ChatScene from pilot.common.schema import SeparatorStyle -from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser +from pilot.scene.chat_knowledge.inner_db_summary.out_parser import ( + NormalChatOutputParser, +) CFG = Config() -PROMPT_SCENE_DEFINE ="""""" +PROMPT_SCENE_DEFINE = """""" _DEFAULT_TEMPLATE = """ Based on the following known database information, answer which tables are involved in the user input. Known database information:{db_profile_summary} Input:{db_input} You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads -The response format must be JSON, and the key of JSON must be "table". + """ PROMPT_RESPONSE = """You must respond in JSON format as following format: {response} - -Ensure the response is correct json and can be parsed by Python json.loads +The response format must be JSON, and the key of JSON must be "table". 
""" - -RESPONSE_FORMAT = { - "table": ["orders", "products"] - } - +RESPONSE_FORMAT = {"table": ["orders", "products"]} PROMPT_SEP = SeparatorStyle.SINGLE.value @@ -54,5 +51,3 @@ prompt = PromptTemplate( CFG.prompt_templates.update({prompt.template_scene: prompt}) - - diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 0666de9e1..096df92cb 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -1,4 +1,3 @@ - from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -24,18 +23,20 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding CFG = Config() -class ChatUrlKnowledge (BaseChat): +class ChatUrlKnowledge(BaseChat): chat_scene: str = ChatScene.ChatUrlKnowledge.value """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url): + def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatUrlKnowledge, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatUrlKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) self.url = url vector_store_config = { "vector_store_name": url, @@ -54,19 +55,16 @@ class ChatUrlKnowledge (BaseChat): self.knowledge_embedding_client.knowledge_embedding() def generate_input_values(self): - docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = self.knowledge_embedding_client.similar_search( + self.current_user_input, VECTOR_SEARCH_TOP_K + ) docs = docs[:2000] - input_values = { - "context": docs, - "question": self.current_user_input - } + input_values = {"context": docs, "question": self.current_user_input} return input_values def do_with_prompt_response(self, prompt_response): return prompt_response - - @property def chat_type(self) -> str: return ChatScene.ChatUrlKnowledge.value diff --git a/pilot/scene/chat_knowledge/url/out_parser.py b/pilot/scene/chat_knowledge/url/out_parser.py index 0b8277d63..e5edc9b20 100644 --- a/pilot/scene/chat_knowledge/url/out_parser.py +++ b/pilot/scene/chat_knowledge/url/out_parser.py @@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): +class NormalChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: return model_out_text diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py index a5c1fe226..8eaafd61e 100644 --- a/pilot/scene/chat_knowledge/url/prompt.py +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用 """ - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = True @@ -39,5 +38,3 @@ prompt = PromptTemplate( CFG.prompt_templates.update({prompt.template_scene: prompt}) - - diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py index edd6ac53c..f4ba94320 100644 --- a/pilot/scene/chat_normal/chat.py +++ b/pilot/scene/chat_normal/chat.py @@ -1,4 +1,3 @@ - from pilot.scene.base_chat import BaseChat, logger, headers from 
pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -19,25 +18,23 @@ class ChatNormal(BaseChat): """Number of results to return from the query""" - def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): """ """ - super().__init__(temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatNormal, - chat_session_id=chat_session_id, - current_user_input=user_input) + super().__init__( + temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNormal, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) def generate_input_values(self): - input_values = { - "input": self.current_user_input - } + input_values = {"input": self.current_user_input} return input_values def do_with_prompt_response(self, prompt_response): return prompt_response - - @property def chat_type(self) -> str: return ChatScene.ChatNormal.value diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py index 0b8277d63..e5edc9b20 100644 --- a/pilot/scene/chat_normal/out_parser.py +++ b/pilot/scene/chat_normal/out_parser.py @@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") -class NormalChatOutputParser(BaseOutputParser): +class NormalChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: return model_out_text diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index 7a11cf3ff..2ab387952 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -29,5 +29,3 @@ prompt = PromptTemplate( CFG.prompt_templates.update({prompt.template_scene: prompt}) - - diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 6b385d9da..b5df56ca5 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -85,9 +85,7 @@ add_knowledge_base_dialogue = get_lang_text( "knowledge_qa_type_add_knowledge_base_dialogue" ) -url_knowledge_dialogue = get_lang_text( - "knowledge_qa_type_url_knowledge_dialogue" -) +url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue") knowledge_qa_type_list = [ llm_native_dialogue, @@ -205,9 +203,9 @@ def post_process_code(code): def get_chat_mode(selected, param=None) -> ChatScene: - if chat_mode_title['chat_use_plugin'] == selected: + if chat_mode_title["chat_use_plugin"] == selected: return ChatScene.ChatExecution - elif chat_mode_title['knowledge_qa'] == selected: + elif chat_mode_title["knowledge_qa"] == selected: mode = param if mode == conversation_types["default_knownledge"]: return ChatScene.ChatKnowledge @@ -232,14 +230,23 @@ def chatbot_callback(state, message): def http_bot( - state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, - knowledge_name + state, + selected, + temperature, + max_new_tokens, + plugin_selector, + mode, + sql_mode, + db_selector, + url_input, + knowledge_name, ): logger.info( - f"User message sent!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}") - if chat_mode_title['knowledge_qa'] == selected: + f"User message sent!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}" + ) + if chat_mode_title["knowledge_qa"] == selected: scene: ChatScene = get_chat_mode(selected, mode) - elif 
chat_mode_title['chat_use_plugin'] == selected: + elif chat_mode_title["chat_use_plugin"] == selected: scene: ChatScene = get_chat_mode(selected) else: scene: ChatScene = get_chat_mode(selected, sql_mode) @@ -251,7 +258,7 @@ def http_bot( "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "db_name": db_selector, - "user_input": state.last_user_input + "user_input": state.last_user_input, } elif ChatScene.ChatWithDbQA == scene: chat_param = { @@ -289,7 +296,7 @@ def http_bot( "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "user_input": state.last_user_input, - "knowledge_name": knowledge_name + "knowledge_name": knowledge_name, } elif ChatScene.ChatUrlKnowledge == scene: chat_param = { @@ -297,7 +304,7 @@ def http_bot( "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "user_input": state.last_user_input, - "url": url_input + "url": url_input, } else: state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}" @@ -314,7 +321,11 @@ def http_bot( response = chat.stream_call() for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len) + state.messages[-1][ + -1 + ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 except Exception as e: print(traceback.format_exc()) @@ -323,8 +334,8 @@ def http_bot( block_css = ( - code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -361,7 +372,7 @@ def build_single_model_ui(): gr.Markdown(notice_markdown, elem_id="notice_markdown") with gr.Accordion( - get_lang_text("model_control_param"), open=False, visible=False + get_lang_text("model_control_param"), open=False, visible=False ) as parameter_row: temperature = gr.Slider( minimum=0.0, @@ -411,7 +422,7 @@ def build_single_model_ui(): get_lang_text("sql_generate_mode_none"), ], show_label=False, - value=get_lang_text("sql_generate_mode_none") + value=get_lang_text("sql_generate_mode_none"), ) sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting")) sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) @@ -428,15 +439,19 @@ def build_single_model_ui(): value="", interactive=True, show_label=True, - type="value" + type="value", ).style(container=False) - def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData + def plugin_change( + evt: gr.SelectData, + ): # SelectData is a subclass of EventData print(f"You selected {evt.value} at {evt.index} from {evt.target}") print(f"user plugin:{plugins_select_info().get(evt.value)}") return plugins_select_info().get(evt.value) - plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected") + plugin_selected = gr.Textbox( + show_label=False, visible=False, placeholder="Selected" + ) plugin_selector.select(plugin_change, None, plugin_selected) tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA") @@ -456,7 +471,12 @@ def build_single_model_ui(): ) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) - url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False) + url_input = gr.Textbox( + label=get_lang_text("url_input_label"), + lines=1, + interactive=True, + visible=False, + ) def show_url_input(evt: gr.SelectData): if evt.value 
== url_knowledge_dialogue: @@ -559,10 +579,10 @@ def build_single_model_ui(): def build_webdemo(): with gr.Blocks( - title=get_lang_text("database_smart_assistant"), - # theme=gr.themes.Base(), - theme=gr.themes.Default(), - css=block_css, + title=get_lang_text("database_smart_assistant"), + # theme=gr.themes.Base(), + theme=gr.themes.Default(), + css=block_css, ) as demo: url_params = gr.JSON(visible=False) ( diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index aefddd848..bb5331434 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -18,7 +18,14 @@ CFG = Config() class KnowledgeEmbedding: - def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"): + def __init__( + self, + file_path, + model_name, + vector_store_config, + local_persist=True, + file_type="default", + ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.model_name = model_name @@ -63,7 +70,6 @@ class KnowledgeEmbedding: vector_store_config=self.vector_store_config, ) - elif self.file_type == "default": embedding = MarkdownEmbedding( file_path=self.file_path, @@ -71,7 +77,6 @@ class KnowledgeEmbedding: vector_store_config=self.vector_store_config, ) - return embedding def similar_search(self, text, topk): diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index a34b87a93..91805ddd4 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -13,6 +13,7 @@ from pilot.summary.mysql_db_summary import MysqlSummary from pilot.scene.chat_factory import ChatFactory CFG = Config() +chat_factory = ChatFactory() class DBSummaryClient: @@ -88,13 +89,18 @@ class DBSummaryClient: ) if CFG.SUMMARY_CONFIG == "FAST": table_docs = knowledge_embedding_client.similar_search(query, topk) - related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs] + related_tables = [ + json.loads(table_doc.page_content)["table_name"] + for table_doc in table_docs + ] else: table_docs = knowledge_embedding_client.similar_search(query, 1) # prompt = KnownLedgeBaseQA.build_db_summary_prompt( # query, table_docs[0].page_content # ) - related_tables = _get_llm_response(query, dbname, table_docs[0].page_content) + related_tables = _get_llm_response( + query, dbname, table_docs[0].page_content + ) related_table_summaries = [] for table in related_tables: vector_store_config = { @@ -118,35 +124,14 @@ def _get_llm_response(query, db_input, dbsummary): "max_new_tokens": 512, "chat_session_id": uuid.uuid1(), "user_input": query, - "db_input": db_input, + "db_select": db_input, "db_summary": dbsummary, } - chat_factory = ChatFactory() - chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value(), **chat_param) - - return chat.call() - # payload = { - # "model": CFG.LLM_MODEL, - # "prompt": prompt, - # "temperature": float(0.7), - # "max_new_tokens": int(512), - # "stop": state.sep - # if state.sep_style == SeparatorStyle.SINGLE - # else state.sep2, - # } - # headers = {"User-Agent": "dbgpt Client"} - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # - # print(related_tables) - # return related_tables - # except NotCommands as e: - # print("llm response error:" + e.message) - + chat: BaseChat = chat_factory.get_implementation( + 
ChatScene.InnerChatDBSummary.value, **chat_param + ) + res = chat.nostream_call() + return json.loads(res)["table"] # if __name__ == "__main__": diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py index 3ed9b9171..a50b24f94 100644 --- a/pilot/summary/mysql_db_summary.py +++ b/pilot/summary/mysql_db_summary.py @@ -37,20 +37,23 @@ class MysqlSummary(DBSummary): table_name=table_comment[0], table_comment=table_comment[1] ) ) + vector_table = json.dumps( {"table_name": table_comment[0], "table_description": table_comment[1]} ) self.vector_tables_info.append( vector_table.encode("utf-8").decode("unicode_escape") ) - + self.table_columns_info = [] for table_name in tables: table_summary = MysqlTableSummary(self.db, name, table_name) - self.tables[table_name] = table_summary.get_summery() + # self.tables[table_name] = table_summary.get_summery() + self.tables[table_name] = table_summary.get_columns() + self.table_columns_info.append(table_summary.get_columns()) # self.tables_info.append(table_summary.get_summery()) def get_summery(self): - if CFG.SUMMARY_CONFIG == "VECTOR": + if CFG.SUMMARY_CONFIG == "FAST": return self.vector_tables_info else: return self.summery.format( @@ -63,6 +66,9 @@ class MysqlSummary(DBSummary): def get_table_comments(self): return self.table_comments + def get_columns(self): + return self.table_columns_info + class MysqlTableSummary(TableSummary): """Get mysql table summary template.""" @@ -78,10 +84,16 @@ class MysqlTableSummary(TableSummary): self.db = instance fields = self.db.get_fields(name) indexes = self.db.get_indexes(name) + field_names = [] for field in fields: field_summary = MysqlFieldsSummary(field) self.fields.append(field_summary) self.fields_info.append(field_summary.get_summery()) + field_names.append(field[0]) + + self.column_summery = """{name}({columns_info})""".format( + name=name, columns_info=",".join(field_names) + ) for index in indexes: index_summary = MysqlIndexSummary(index) @@ -96,6 +108,9 @@ class MysqlTableSummary(TableSummary): indexes=";".join(self.indexes_info), ) + def get_columns(self): + return self.column_summery + class MysqlFieldsSummary(FieldSummary): """Get mysql field summary template.""" diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 6c7028856..6672d3d23 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,4 +1,5 @@ from pilot.vector_store.chroma_store import ChromaStore + # from pilot.vector_store.milvus_store import MilvusStore connector = {"Chroma": ChromaStore, "Milvus": None}
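
A note for reviewers: below is a minimal usage sketch (not part of the patch) of the db summary flow this change wires together. The call mirrors how pilot/scene/chat_db/professional_qa/chat.py now consumes the client; the database name and question are illustrative, and it assumes the target database has already been summarized into the vector store and that CFG.SUMMARY_CONFIG is configured.

# Usage sketch only -- "mydb" and the query are hypothetical placeholders.
from pilot.summary.db_summary_client import DBSummaryClient

# With SUMMARY_CONFIG == "FAST", related tables are picked by vector search over
# the per-table JSON summaries written by MysqlSummary
# ({"table_name": ..., "table_description": ...}).
# Otherwise the new InnerChatDBSummary scene asks the LLM, whose nostream_call()
# must return JSON with a "table" key, e.g. {"table": ["orders", "products"]}.
table_info = DBSummaryClient.get_similar_tables(
    dbname="mydb",  # illustrative database name
    query="total sales per product last month",  # illustrative question
    topk=5,
)
print(table_info)  # summaries of the tables most relevant to the query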