diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py
index 09931c90e..40646c309 100644
--- a/pilot/common/plugins.py
+++ b/pilot/common/plugins.py
@@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
     print("load_native_plugins")
     ### TODO 默认拉主分支,后续拉发布版本
     branch_name = cfg.plugins_git_branch
-    native_plugin_repo ="DB-GPT-Plugins"
+    native_plugin_repo = "DB-GPT-Plugins"
     url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
-    response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
-                            headers={'Authorization': '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'})
+    response = requests.get(
+        url.format(repo=native_plugin_repo, branch=branch_name),
+        headers={"Authorization": "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"},
+    )
     if response.status_code == 200:
         plugins_path_path = Path(PLUGINS_DIR)
-        files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
+        files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
         for file in files:
             os.remove(file)
         now = datetime.datetime.now()
-        time_str = now.strftime('%Y%m%d%H%M%S')
-        file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
+        time_str = now.strftime("%Y%m%d%H%M%S")
+        file_name = (
+            f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
+        )
         print(file_name)
         with open(file_name, "wb") as f:
             f.write(response.content)
diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py
index 5ccfb7902..d59a9d33f 100644
--- a/pilot/common/sql_database.py
+++ b/pilot/common/sql_database.py
@@ -66,7 +66,6 @@ class Database:
         self._sample_rows_in_table_info = set()
         self._indexes_in_table_info = indexes_in_table_info
 
-
     @classmethod
     def from_uri(
         cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
@@ -399,7 +398,6 @@ class Database:
             ans = cursor.fetchall()
             return ans[0][1]
 
-
     def get_fields(self, table_name):
         """Get column fields about specified table."""
         session = self._db_sessions()
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 4bda464a7..0dc78af06 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
 DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
 DATA_DIR = os.path.join(PILOT_PATH, "data")
 nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
-PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") 
-FONT_DIR = os.path.join(PILOT_PATH, "fonts") 
+PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
+FONT_DIR = os.path.join(PILOT_PATH, "fonts")
 
 current_directory = os.getcwd()
 
diff --git a/pilot/connections/rdbms/clickhouse.py b/pilot/connections/rdbms/clickhouse.py
index c7421e8e6..3e243759d 100644
--- a/pilot/connections/rdbms/clickhouse.py
+++ b/pilot/connections/rdbms/clickhouse.py
@@ -6,6 +6,7 @@ from pilot.configs.config import Config
 
 CFG = Config()
 
+
 class ClickHouseConnector(RDBMSDatabase):
     """ClickHouseConnector"""
 
@@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
 
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]
 
-
     @classmethod
     def from_config(cls) -> RDBMSDatabase:
         """
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_PATH,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_PATH,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
     @classmethod
-    def from_uri_db(cls, db_path: str,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
+    def from_uri_db(
+        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
         db_url: str = cls.connect_driver + "://" + db_path
         return cls.from_uri(db_url, engine_args, **kwargs)
diff --git a/pilot/connections/rdbms/duckdb.py b/pilot/connections/rdbms/duckdb.py
index e8b1038cb..947807744 100644
--- a/pilot/connections/rdbms/duckdb.py
+++ b/pilot/connections/rdbms/duckdb.py
@@ -6,6 +6,7 @@ from pilot.configs.config import Config
 
 CFG = Config()
 
+
 class DuckDbConnect(RDBMSDatabase):
     """Connect Duckdb Database fetch MetaData
     Args:
@@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
 
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]
 
-
     @classmethod
     def from_config(cls) -> RDBMSDatabase:
         """
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_PATH,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_PATH,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
     @classmethod
-    def from_uri_db(cls, db_path: str,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
+    def from_uri_db(
+        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
         db_url: str = cls.connect_driver + "://" + db_path
         return cls.from_uri(db_url, engine_args, **kwargs)
diff --git a/pilot/connections/rdbms/mssql.py b/pilot/connections/rdbms/mssql.py
index 89c37e757..ceab845c4 100644
--- a/pilot/connections/rdbms/mssql.py
+++ b/pilot/connections/rdbms/mssql.py
@@ -5,9 +5,6 @@ from typing import Optional, Any
 
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
 
-
-
-
 class MSSQLConnect(RDBMSDatabase):
     """Connect MSSQL Database fetch MetaData
     Args:
@@ -18,6 +15,4 @@
     dialect: str = "mssql"
     driver: str = "pyodbc"
 
-    default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
-
-
+    default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
diff --git a/pilot/connections/rdbms/mysql.py b/pilot/connections/rdbms/mysql.py
index c1b57f784..8acf90759 100644
--- a/pilot/connections/rdbms/mysql.py
+++ b/pilot/connections/rdbms/mysql.py
@@ -5,9 +5,6 @@ from typing import Optional, Any
 
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
 
-
-
-
 class MySQLConnect(RDBMSDatabase):
     """Connect MySQL Database fetch MetaData
     Args:
@@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
     driver: str = "pymysql"
 
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]
-
-
diff --git a/pilot/connections/rdbms/oracle.py b/pilot/connections/rdbms/oracle.py
index 8c5c0d004..8959695b0 100644
--- a/pilot/connections/rdbms/oracle.py
+++ b/pilot/connections/rdbms/oracle.py
@@ -2,8 +2,10 @@
 # -*- coding:utf-8 -*-
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
 
+
 class OracleConnector(RDBMSDatabase):
     """OracleConnector"""
+
     type: str = "ORACLE"
 
     driver: str = "oracle"
diff --git a/pilot/connections/rdbms/postgres.py b/pilot/connections/rdbms/postgres.py
index 2d366566a..104380a37 100644
--- a/pilot/connections/rdbms/postgres.py
+++ b/pilot/connections/rdbms/postgres.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
 
+
 class PostgresConnector(RDBMSDatabase):
     """PostgresConnector is a class which Connector"""
 
diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py
index 68784f9b7..5a2b3edae 100644
--- a/pilot/connections/rdbms/py_study/pd_study.py
+++ b/pilot/connections/rdbms/py_study/pd_study.py
@@ -57,18 +57,19 @@ CFG = Config()
 
 
 if __name__ == "__main__":
-    def __extract_json(s):
-        i = s.index('{')
-        count = 1  # 当前所在嵌套深度,即还没闭合的'{'个数
-        for j, c in enumerate(s[i + 1:], start=i + 1):
-            if c == '}':
-                count -= 1
-            elif c == '{':
-                count += 1
-            if count == 0:
-                break
-        assert (count == 0)  # 检查是否找到最后一个'}'
-        return s[i:j + 1]
-    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
-    print(__extract_json(ss))
\ No newline at end of file
+    def __extract_json(s):
+        i = s.index("{")
+        count = 1  # 当前所在嵌套深度,即还没闭合的'{'个数
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
+            if c == "}":
+                count -= 1
+            elif c == "{":
+                count += 1
+            if count == 0:
+                break
+        assert count == 0  # 检查是否找到最后一个'}'
+        return s[i : j + 1]
+
+    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
+    print(__extract_json(ss))
diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py
index 424bfaa7f..7fef1862f 100644
--- a/pilot/connections/rdbms/rdbms_connect.py
+++ b/pilot/connections/rdbms/rdbms_connect.py
@@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
     """SQLAlchemy wrapper around a database."""
 
     def __init__(
-            self,
-            engine,
-            schema: Optional[str] = None,
-            metadata: Optional[MetaData] = None,
-            ignore_tables: Optional[List[str]] = None,
-            include_tables: Optional[List[str]] = None,
-
+        self,
+        engine,
+        schema: Optional[str] = None,
+        metadata: Optional[MetaData] = None,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
     ):
         """Create engine from database URI."""
         self._engine = engine
@@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_HOST,
-                               CFG.LOCAL_DB_PORT,
-                               CFG.LOCAL_DB_USER,
-                               CFG.LOCAL_DB_PASSWORD,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_HOST,
+            CFG.LOCAL_DB_PORT,
+            CFG.LOCAL_DB_USER,
+            CFG.LOCAL_DB_PASSWORD,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
     @classmethod
-    def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
-        db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
-            CFG.LOCAL_DB_PORT)
+    def from_uri_db(
+        cls,
+        host: str,
+        port: int,
+        user: str,
+        pwd: str,
+        db_name: str = None,
+        engine_args: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> RDBMSDatabase:
+        db_url: str = (
+            cls.connect_driver
+            + "://"
+            + CFG.LOCAL_DB_USER
+            + ":"
+            + CFG.LOCAL_DB_PASSWORD
+            + "@"
+            + CFG.LOCAL_DB_HOST
+            + ":"
+            + str(CFG.LOCAL_DB_PORT)
+        )
         if cls.dialect:
             db_url = cls.dialect + "+" + db_url
         if db_name:
@@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
 
     @classmethod
     def from_uri(
-            cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
     ) -> RDBMSDatabase:
         """Construct a SQLAlchemy engine from URI."""
         _engine_args = engine_args or {}
@@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
             tbl
             for tbl in self._metadata.sorted_tables
             if tbl.name in set(all_table_names)
-               and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
+            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
         ]
 
         tables = []
@@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
             create_table = str(CreateTable(table).compile(self._engine))
             table_info = f"{create_table.rstrip()}"
             has_extra_info = (
-                    self._indexes_in_table_info or self._sample_rows_in_table_info
+                self._indexes_in_table_info or self._sample_rows_in_table_info
             )
             if has_extra_info:
                 table_info += "\n\n/*"
diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
index 8e98ed4c9..4336d43e3 100644
--- a/pilot/model/llm_out/proxy_llm.py
+++ b/pilot/model/llm_out/proxy_llm.py
@@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         }
     )
 
-    # Move the last user's information to the end 
+    # Move the last user's information to the end
     temp_his = history[::-1]
     last_user_input = None
     for m in temp_his:
@@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     text = ""
     for line in res.iter_lines():
         if line:
-            json_data = line.split(b': ', 1)[1]
+            json_data = line.split(b": ", 1)[1]
             decoded_line = json_data.decode("utf-8")
             if decoded_line.lower() != "[DONE]".lower():
                 obj = json.loads(json_data)
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 3b4c9e028..bd968aef1 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
     def __extract_json(slef, s):
-        i = s.index('{')
+        i = s.index("{")
         count = 1  # 当前所在嵌套深度,即还没闭合的'{'个数
-        for j, c in enumerate(s[i + 1:], start=i + 1):
-            if c == '}':
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
+            if c == "}":
                 count -= 1
-            elif c == '{':
+            elif c == "{":
                 count += 1
             if count == 0:
                 break
-        assert (count == 0)  # 检查是否找到最后一个'}'
-        return s[i:j + 1]
+        assert count == 0  # 检查是否找到最后一个'}'
+        return s[i : j + 1]
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 497b2cd10..0120b9e86 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -134,7 +134,6 @@ class BaseChat(ABC):
         return payload
 
     def stream_call(self):
-
         # TODO Retry when server connection error
         payload = self.__call_base()
 
@@ -189,19 +188,19 @@ class BaseChat(ABC):
                 )
             )
 
-# ### MOCK
-# ai_response_text = """{
-# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
-# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
-# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
-# "command": {
-# "name": "histogram-executor",
-# "args": {
-# "title": "订单城市分布柱状图",
-# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-# }
-# }
-# }"""
+            # ### MOCK
+            # ai_response_text = """{
+            # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
+            # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
+            # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
+            # "command": {
+            # "name": "histogram-executor",
+            # "args": {
+            # "title": "订单城市分布柱状图",
+            # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
+            # }
+            # }
+            # }"""
             self.current_message.add_ai_message(ai_response_text)
 
             prompt_define_response = (
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index f91af967c..1dcb4c6ed 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
     def __list_to_prompt_str(self, list: List) -> str:
         return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
 
-
     def generate(self, p) -> str:
         return super().generate(p)
 
diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py
index 44f203d1e..565d54c5e 100644
--- a/pilot/scene/chat_execution/out_parser.py
+++ b/pilot/scene/chat_execution/out_parser.py
@@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
         command, thoughts, speak = (
             response["command"],
             response["thoughts"],
-            response["speak"]
+            response["speak"],
         )
         return PluginAction(command, speak, thoughts)
 
diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py
index 3f21b828d..6116deecd 100644
--- a/pilot/scene/chat_knowledge/default/chat.py
+++ b/pilot/scene/chat_knowledge/default/chat.py
@@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
             context = context[:2000]
             input_values = {"context": context, "question": self.current_user_input}
         except NoIndexException:
-            raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
+            raise ValueError(
+                "you have no default knowledge store, please execute python knowledge_init.py"
+            )
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py
index 55f525988..ac72fc637 100644
--- a/pilot/server/__init__.py
+++ b/pilot/server/__init__.py
@@ -5,7 +5,6 @@ import sys
 
 from dotenv import load_dotenv
 
-
 if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
     print("Setting random seed to 42")
     random.seed(42)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index beab61d4a..1e3a4dcb3 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -87,7 +87,10 @@ class ModelWorker:
                 ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
                 yield json.dumps(ret).encode() + b"\0"
         except Exception as e:
-            ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0}
+            ret = {
+                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+                "error_code": 0,
+            }
             yield json.dumps(ret).encode() + b"\0"
 
     def get_embeddings(self, prompt):
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index e76865550..761a239e7 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -667,8 +667,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")
-    
-    # init config 
+
+    # init config
     cfg = Config()
 
     load_native_plugins(cfg)
@@ -682,7 +682,7 @@ if __name__ == "__main__":
         "pilot.commands.built_in.audio_text",
         "pilot.commands.built_in.image_gen",
     ]
-    # exclude commands 
+    # exclude commands
     command_categories = [
         x for x in command_categories if x not in cfg.disabled_command_categories
     ]
diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py
index 5f6d9526d..60046d0cd 100644
--- a/pilot/source_embedding/markdown_embedding.py
+++ b/pilot/source_embedding/markdown_embedding.py
@@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
     def read(self):
         """Load from markdown path."""
         loader = EncodeTextLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
+        )
         return loader.load_and_split(textsplitter)
 
     @register
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index 66b0963d9..87ad9d1cf 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
         # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
         textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
         )
         return loader.load_and_split(textsplitter)
 
diff --git a/pilot/source_embedding/ppt_embedding.py b/pilot/source_embedding/ppt_embedding.py
index 869e92395..583b29ed1 100644
--- a/pilot/source_embedding/ppt_embedding.py
+++ b/pilot/source_embedding/ppt_embedding.py
@@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
     def read(self):
         """Load from ppt path."""
         loader = UnstructuredPowerPointLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
+        )
         return loader.load_and_split(textsplitter)
 
     @register
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 84fbf1550..5e551514b 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -78,7 +78,7 @@ class DBSummaryClient:
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
-        table_docs =knowledge_embedding_client.similar_search(query, topk)
+        table_docs = knowledge_embedding_client.similar_search(query, topk)
         ans = [d.page_content for d in table_docs]
         return ans
 
@@ -147,8 +147,6 @@ class DBSummaryClient:
 
         logger.info("init db profile success...")
 
-
-
 def _get_llm_response(query, db_input, dbsummary):
     chat_param = {
         "temperature": 0.7,
diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py
index 4a578fe2c..08a01c0fc 100644
--- a/pilot/summary/mysql_db_summary.py
+++ b/pilot/summary/mysql_db_summary.py
@@ -43,15 +43,14 @@ CFG = Config()
 #     "tps": 50
 # }
 
+
 class MysqlSummary(DBSummary):
     """Get mysql summary template."""
 
     def __init__(self, name):
         self.name = name
         self.type = "MYSQL"
-        self.summery = (
-            """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
-        )
+        self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
         self.tables = {}
         self.tables_info = []
         self.vector_tables_info = []
@@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
             self.tables[table_name] = table_summary.get_columns()
             self.table_columns_info.append(table_summary.get_columns())
             # self.table_columns_json.append(table_summary.get_summary_json())
-            table_profile = "table name:{table_name},table description:{table_comment}".format(
-                table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
+            table_profile = (
+                "table name:{table_name},table description:{table_comment}".format(
+                    table_name=table_name,
+                    table_comment=self.db.get_show_create_table(table_name),
+                )
             )
             self.table_columns_json.append(table_profile)
             # self.tables_info.append(table_summary.get_summery())
@@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
 
     def get_db_summery(self):
         return self.summery.format(
-            name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
+            name=self.name,
+            type=self.type,
+            tables=";".join(self.vector_tables_info),
+            qps=1000,
+            tps=1000,
         )
 
     def get_table_summary(self):
@@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
             self.indexes_info.append(index_summary.get_summery())
 
         self.json_summery = self.json_summery_template.format(
-            name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
+            name=name,
+            comment=comment_map[name],
+            fields=self.fields_info,
+            indexes=self.indexes_info,
+            size_in_bytes=1000,
+            rows=1000,
         )
 
     def get_summery(self):
@@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
         self.bind_fields = index[1]
 
     def get_summery(self):
-        return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
+        return self.summery_template.format(
+            name=self.name, bind_fields=self.bind_fields
+        )
 
 
 if __name__ == "__main__":