From 545c3232161b11a1e6dfccaa3bb393718a745bc5 Mon Sep 17 00:00:00 2001 From: ykgong Date: Fri, 9 Jun 2023 11:35:06 +0800 Subject: [PATCH 01/10] add gpt4all --- pilot/configs/model_config.py | 1 + pilot/model/adapter.py | 23 +++++++++++++++++------ pilot/model/llm_out/gpt4all_llm.py | 17 +++++++++++++++++ pilot/out_parser/base.py | 2 +- pilot/server/chat_adapter.py | 15 ++++++++++++--- pilot/server/llmserver.py | 5 +++-- 6 files changed, 51 insertions(+), 12 deletions(-) create mode 100644 pilot/model/llm_out/gpt4all_llm.py diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 36d615043..b85fe6b7b 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = { "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), + "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"), "proxyllm": "proxyllm", } diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 7892e4b1b..89ea55ec2 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- import torch +import os +from functools import cache from typing import List from functools import cache from transformers import ( @@ -92,8 +94,8 @@ class ChatGLMAdapater(BaseLLMAdaper): AutoModel.from_pretrained( model_path, trust_remote_code=True, **from_pretrained_kwargs ) - .half() - .cuda() + .half() + .cuda() ) return model, tokenizer @@ -185,18 +187,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper): class GPT4AllAdapter(BaseLLMAdaper): - """A light version for someone who want practise LLM use laptop.""" + """ + A light version for someone who want practise LLM use laptop. + All model names see: https://gpt4all.io/models/models.json + """ def match(self, model_path: str): return "gpt4all" in model_path def loader(self, model_path: str, from_pretrained_kwargs: dict): - # TODO - pass + import gpt4all + + if model_path is None and from_pretrained_kwargs.get('model_name') is None: + model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy") + else: + path, file = os.path.split(model_path) + model = gpt4all.GPT4All(model_path=path, model_name=file) + return model, None class ProxyllmAdapter(BaseLLMAdaper): - """The model adapter for local proxy""" def match(self, model_path: str): @@ -211,6 +221,7 @@ register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) register_llm_model_adapters(FalconAdapater) register_llm_model_adapters(GorillaAdapter) +register_llm_model_adapters(GPT4AllAdapter) # TODO Default support vicuna, other model need to tests and Evaluate # just for test, remove this later diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py new file mode 100644 index 000000000..4cc1f067f --- /dev/null +++ b/pilot/model/llm_out/gpt4all_llm.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings): + stop = params.get("stop", "###") + prompt = params["prompt"] + role, query = prompt.split(stop)[1].split(":") + print(f"gpt4all, role: {role}, query: {query}") + + messages = [{"role": "user", "content": query}] + res = model.chat_completion(messages) + if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \ + res.get('choices')[0].get('message').get('content'): + yield res.get('choices')[0].get('message').get('content') + else: + yield "error response" + diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 513c1d300..6f08d93fe 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -51,7 +51,7 @@ class BaseOutputParser(ABC): """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ - if data["error_code"] == 0: + if data.get('error_code', 0) == 0: if "vicuna" in CFG.LLM_MODEL: # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index e4f57cf46..3598b16b3 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter: class VicunaChatAdapter(BaseChatAdpter): - """Model chat Adapter for vicuna""" def match(self, model_path: str): @@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter): class CodeT5ChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeT5""" def match(self, model_path: str): @@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter): class CodeGenChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeGen""" def match(self, model_path: str): @@ -127,11 +124,23 @@ class GorillaChatAdapter(BaseChatAdpter): return generate_stream +class GPT4AllChatAdapter(BaseChatAdpter): + + def match(self, model_path: str): + return "gpt4all" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream + + return gpt4all_generate_stream + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter) +register_llm_model_chat_adapter(GPT4AllChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index d2730e0d5..e71872d64 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -39,9 +39,9 @@ class ModelWorker: ) if not isinstance(self.model, str): - if hasattr(self.model.config, "max_sequence_length"): + if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"): self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model.config, "max_position_embeddings"): + elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"): self.context_len = self.model.config.max_position_embeddings else: @@ -66,6 +66,7 @@ class ModelWorker: def generate_stream_gate(self, params): try: + print(f"llmserver params: {params}, self: {self}") for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): From cacab625cf1b3ea09f83233816d7dd996a24919a Mon Sep 17 00:00:00 2001 From: ykgong Date: Fri, 9 Jun 2023 11:39:51 +0800 Subject: [PATCH 02/10] rm log --- pilot/model/adapter.py | 1 - pilot/server/llmserver.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 89ea55ec2..435be1142 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -3,7 +3,6 @@ import torch import os -from functools import cache from typing import List from functools import cache from transformers import ( diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index e71872d64..ad4627afa 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -66,7 +66,6 @@ class ModelWorker: def generate_stream_gate(self, params): try: - print(f"llmserver params: {params}, self: {self}") for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): From 7136aa748dce63c2f1235d9606c2686d0437cc0f Mon Sep 17 00:00:00 2001 From: ykgong Date: Fri, 9 Jun 2023 13:57:47 +0800 Subject: [PATCH 03/10] fix model key --- pilot/configs/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index b85fe6b7b..851a0486d 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -37,7 +37,7 @@ LLM_MODEL_CONFIG = { "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), - "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"), + "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"), "proxyllm": "proxyllm", } From fc4fd6aa773df80d4234629156d2ad3750851024 Mon Sep 17 00:00:00 2001 From: ykgong Date: Fri, 9 Jun 2023 14:05:49 +0800 Subject: [PATCH 04/10] requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9238751ca..c6434c3ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,6 +49,7 @@ llama-index==0.5.27 pymysql unstructured==0.6.3 grpcio==1.47.5 +gpt4all==0.3.0 auto-gpt-plugin-template pymdown-extensions From a88e8aa51bf27bf5540655d73bd579cf7cfae7bd Mon Sep 17 00:00:00 2001 From: sheri528 Date: Tue, 13 Jun 2023 14:11:23 +0800 Subject: [PATCH 05/10] update stream output --- pilot/model/adapter.py | 6 +++--- pilot/model/llm_out/gpt4all_llm.py | 20 +++++++++++++------- pilot/out_parser/base.py | 2 +- pilot/server/chat_adapter.py | 1 - pilot/server/llmserver.py | 11 ++++++++--- 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 435be1142..407d11127 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -187,8 +187,8 @@ class RWKV4LLMAdapter(BaseLLMAdaper): class GPT4AllAdapter(BaseLLMAdaper): """ - A light version for someone who want practise LLM use laptop. - All model names see: https://gpt4all.io/models/models.json + A light version for someone who want practise LLM use laptop. + All model names see: https://gpt4all.io/models/models.json """ def match(self, model_path: str): @@ -197,7 +197,7 @@ class GPT4AllAdapter(BaseLLMAdaper): def loader(self, model_path: str, from_pretrained_kwargs: dict): import gpt4all - if model_path is None and from_pretrained_kwargs.get('model_name') is None: + if model_path is None and from_pretrained_kwargs.get("model_name") is None: model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy") else: path, file = os.path.split(model_path) diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py index 4cc1f067f..5ea72f911 100644 --- a/pilot/model/llm_out/gpt4all_llm.py +++ b/pilot/model/llm_out/gpt4all_llm.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- +import threading +import sys +import time + def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings): stop = params.get("stop", "###") @@ -7,11 +11,13 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed role, query = prompt.split(stop)[1].split(":") print(f"gpt4all, role: {role}, query: {query}") - messages = [{"role": "user", "content": query}] - res = model.chat_completion(messages) - if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \ - res.get('choices')[0].get('message').get('content'): - yield res.get('choices')[0].get('message').get('content') - else: - yield "error response" + def worker(): + model.generate(prompt=query, streaming=True) + t = threading.Thread(target=worker) + t.start() + + while t.is_alive(): + yield sys.stdout.output + time.sleep(0.1) + t.join() diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 6f08d93fe..6406f30dd 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -51,7 +51,7 @@ class BaseOutputParser(ABC): """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ - if data.get('error_code', 0) == 0: + if data.get("error_code", 0) == 0: if "vicuna" in CFG.LLM_MODEL: # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 3598b16b3..ebab2d2d4 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -125,7 +125,6 @@ class GorillaChatAdapter(BaseChatAdpter): class GPT4AllChatAdapter(BaseChatAdpter): - def match(self, model_path: str): return "gpt4all" in model_path diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index ad4627afa..66180a406 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -39,9 +39,13 @@ class ModelWorker: ) if not isinstance(self.model, str): - if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"): + if hasattr(self.model, "config") and hasattr( + self.model.config, "max_sequence_length" + ): self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"): + elif hasattr(self.model, "config") and hasattr( + self.model.config, "max_position_embeddings" + ): self.context_len = self.model.config.max_position_embeddings else: @@ -69,7 +73,8 @@ class ModelWorker: for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): - print("output: ", output) + # 生产请不要打开输出!gpt4all线程与父进程共享stdout, 打开会影响前端输出 + # print("output: ", output) ret = { "text": output, "error_code": 0, From 5da4b38964c8662f61e2e409cf1476ae4839d346 Mon Sep 17 00:00:00 2001 From: sheri528 Date: Tue, 13 Jun 2023 14:22:55 +0800 Subject: [PATCH 06/10] format code --- pilot/model/adapter.py | 4 ++-- pilot/server/llmserver.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 407d11127..01d05837b 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -93,8 +93,8 @@ class ChatGLMAdapater(BaseLLMAdaper): AutoModel.from_pretrained( model_path, trust_remote_code=True, **from_pretrained_kwargs ) - .half() - .cuda() + .half() + .cuda() ) return model, tokenizer diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 66180a406..30653a16e 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -73,7 +73,9 @@ class ModelWorker: for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): - # 生产请不要打开输出!gpt4all线程与父进程共享stdout, 打开会影响前端输出 + # Please do not open the output in production! + # The gpt4all thread shares stdout with the parent process, + # and opening it may affect the frontend output. # print("output: ", output) ret = { "text": output, From cc3c2d779936e471ca76e85424c0754cd040c90f Mon Sep 17 00:00:00 2001 From: sheri528 Date: Tue, 13 Jun 2023 14:35:14 +0800 Subject: [PATCH 07/10] update sleep interval --- pilot/model/llm_out/gpt4all_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py index 5ea72f911..7a39a8012 100644 --- a/pilot/model/llm_out/gpt4all_llm.py +++ b/pilot/model/llm_out/gpt4all_llm.py @@ -19,5 +19,5 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed while t.is_alive(): yield sys.stdout.output - time.sleep(0.1) + time.sleep(0.01) t.join() From 1f78c21c088e43fc81ba4f81245436e42267fdb0 Mon Sep 17 00:00:00 2001 From: csunny Date: Tue, 13 Jun 2023 19:46:03 +0800 Subject: [PATCH 08/10] debug: more detail error info --- pilot/server/llmserver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 30653a16e..003641807 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -86,6 +86,9 @@ class ModelWorker: except torch.cuda.CudaError: ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0} + yield json.dumps(ret).encode() + b"\0" def get_embeddings(self, prompt): return get_embeddings(self.model, self.tokenizer, prompt) From d839180d67b144cf00b4b53d2cd27d8625428828 Mon Sep 17 00:00:00 2001 From: csunny Date: Tue, 13 Jun 2023 21:17:03 +0800 Subject: [PATCH 09/10] Todo: add retry for generate --- pilot/scene/base_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index c1d831d0d..1eff3312c 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -136,6 +136,8 @@ class BaseChat(ABC): return payload def stream_call(self): + + # TODO Retry when server connection error payload = self.__call_base() self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 From 4f82cfde6344604462871d7bfdfc5e217db7bdbc Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 14 Jun 2023 10:17:53 +0800 Subject: [PATCH 10/10] pylint: multi model for gp4all (#138) --- pilot/common/plugins.py | 16 ++++-- pilot/common/sql_database.py | 2 - pilot/configs/model_config.py | 4 +- pilot/connections/rdbms/clickhouse.py | 15 +++-- pilot/connections/rdbms/duckdb.py | 15 +++-- pilot/connections/rdbms/mssql.py | 7 +-- pilot/connections/rdbms/mysql.py | 5 -- pilot/connections/rdbms/oracle.py | 2 + pilot/connections/rdbms/postgres.py | 1 + pilot/connections/rdbms/py_study/pd_study.py | 29 +++++----- pilot/connections/rdbms/rdbms_connect.py | 58 +++++++++++++------- pilot/model/llm_out/proxy_llm.py | 4 +- pilot/out_parser/base.py | 12 ++-- pilot/scene/base_chat.py | 27 +++++---- pilot/scene/chat_execution/chat.py | 1 - pilot/scene/chat_execution/out_parser.py | 2 +- pilot/scene/chat_knowledge/default/chat.py | 4 +- pilot/server/__init__.py | 1 - pilot/server/llmserver.py | 5 +- pilot/server/webserver.py | 6 +- pilot/source_embedding/markdown_embedding.py | 6 +- pilot/source_embedding/pdf_embedding.py | 4 +- pilot/source_embedding/ppt_embedding.py | 6 +- pilot/summary/db_summary_client.py | 4 +- pilot/summary/mysql_db_summary.py | 29 +++++++--- 25 files changed, 154 insertions(+), 111 deletions(-) diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index 09931c90e..40646c309 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config): print("load_native_plugins") ### TODO 默认拉主分支,后续拉发布版本 branch_name = cfg.plugins_git_branch - native_plugin_repo ="DB-GPT-Plugins" + native_plugin_repo = "DB-GPT-Plugins" url = "https://github.com/csunny/{repo}/archive/{branch}.zip" - response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name), - headers={'Authorization': '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'}) + response = requests.get( + url.format(repo=native_plugin_repo, branch=branch_name), + headers={"Authorization": "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"}, + ) if response.status_code == 200: plugins_path_path = Path(PLUGINS_DIR) - files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) + files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*")) for file in files: os.remove(file) now = datetime.datetime.now() - time_str = now.strftime('%Y%m%d%H%M%S') - file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" + time_str = now.strftime("%Y%m%d%H%M%S") + file_name = ( + f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" + ) print(file_name) with open(file_name, "wb") as f: f.write(response.content) diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index 5ccfb7902..d59a9d33f 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -66,7 +66,6 @@ class Database: self._sample_rows_in_table_info = set() self._indexes_in_table_info = indexes_in_table_info - @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any @@ -399,7 +398,6 @@ class Database: ans = cursor.fetchall() return ans[0][1] - def get_fields(self, table_name): """Get column fields about specified table.""" session = self._db_sessions() diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 4bda464a7..0dc78af06 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATA_DIR = os.path.join(PILOT_PATH, "data") nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path -PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") -FONT_DIR = os.path.join(PILOT_PATH, "fonts") +PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") +FONT_DIR = os.path.join(PILOT_PATH, "fonts") current_directory = os.getcwd() diff --git a/pilot/connections/rdbms/clickhouse.py b/pilot/connections/rdbms/clickhouse.py index c7421e8e6..3e243759d 100644 --- a/pilot/connections/rdbms/clickhouse.py +++ b/pilot/connections/rdbms/clickhouse.py @@ -6,6 +6,7 @@ from pilot.configs.config import Config CFG = Config() + class ClickHouseConnector(RDBMSDatabase): """ClickHouseConnector""" @@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase): default_db = ["information_schema", "performance_schema", "sys", "mysql"] - @classmethod def from_config(cls) -> RDBMSDatabase: """ Todo password encryption Returns: """ - return cls.from_uri_db(cls, - CFG.LOCAL_DB_PATH, - engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) + return cls.from_uri_db( + cls, + CFG.LOCAL_DB_PATH, + engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}, + ) @classmethod - def from_uri_db(cls, db_path: str, - engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: + def from_uri_db( + cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any + ) -> RDBMSDatabase: db_url: str = cls.connect_driver + "://" + db_path return cls.from_uri(db_url, engine_args, **kwargs) diff --git a/pilot/connections/rdbms/duckdb.py b/pilot/connections/rdbms/duckdb.py index e8b1038cb..947807744 100644 --- a/pilot/connections/rdbms/duckdb.py +++ b/pilot/connections/rdbms/duckdb.py @@ -6,6 +6,7 @@ from pilot.configs.config import Config CFG = Config() + class DuckDbConnect(RDBMSDatabase): """Connect Duckdb Database fetch MetaData Args: @@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase): default_db = ["information_schema", "performance_schema", "sys", "mysql"] - @classmethod def from_config(cls) -> RDBMSDatabase: """ Todo password encryption Returns: """ - return cls.from_uri_db(cls, - CFG.LOCAL_DB_PATH, - engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) + return cls.from_uri_db( + cls, + CFG.LOCAL_DB_PATH, + engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}, + ) @classmethod - def from_uri_db(cls, db_path: str, - engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: + def from_uri_db( + cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any + ) -> RDBMSDatabase: db_url: str = cls.connect_driver + "://" + db_path return cls.from_uri(db_url, engine_args, **kwargs) diff --git a/pilot/connections/rdbms/mssql.py b/pilot/connections/rdbms/mssql.py index 89c37e757..ceab845c4 100644 --- a/pilot/connections/rdbms/mssql.py +++ b/pilot/connections/rdbms/mssql.py @@ -5,9 +5,6 @@ from typing import Optional, Any from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - - - class MSSQLConnect(RDBMSDatabase): """Connect MSSQL Database fetch MetaData Args: @@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase): dialect: str = "mssql" driver: str = "pyodbc" - default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"] - - + default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"] diff --git a/pilot/connections/rdbms/mysql.py b/pilot/connections/rdbms/mysql.py index c1b57f784..8acf90759 100644 --- a/pilot/connections/rdbms/mysql.py +++ b/pilot/connections/rdbms/mysql.py @@ -5,9 +5,6 @@ from typing import Optional, Any from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - - - class MySQLConnect(RDBMSDatabase): """Connect MySQL Database fetch MetaData Args: @@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase): driver: str = "pymysql" default_db = ["information_schema", "performance_schema", "sys", "mysql"] - - diff --git a/pilot/connections/rdbms/oracle.py b/pilot/connections/rdbms/oracle.py index 8c5c0d004..8959695b0 100644 --- a/pilot/connections/rdbms/oracle.py +++ b/pilot/connections/rdbms/oracle.py @@ -2,8 +2,10 @@ # -*- coding:utf-8 -*- from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase + class OracleConnector(RDBMSDatabase): """OracleConnector""" + type: str = "ORACLE" driver: str = "oracle" diff --git a/pilot/connections/rdbms/postgres.py b/pilot/connections/rdbms/postgres.py index 2d366566a..104380a37 100644 --- a/pilot/connections/rdbms/postgres.py +++ b/pilot/connections/rdbms/postgres.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase + class PostgresConnector(RDBMSDatabase): """PostgresConnector is a class which Connector""" diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 68784f9b7..5a2b3edae 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -57,18 +57,19 @@ CFG = Config() if __name__ == "__main__": - def __extract_json(s): - i = s.index('{') - count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1:], start=i + 1): - if c == '}': - count -= 1 - elif c == '{': - count += 1 - if count == 0: - break - assert (count == 0) # 检查是否找到最后一个'}' - return s[i:j + 1] - ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" - print(__extract_json(ss)) \ No newline at end of file + def __extract_json(s): + i = s.index("{") + count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 + for j, c in enumerate(s[i + 1 :], start=i + 1): + if c == "}": + count -= 1 + elif c == "{": + count += 1 + if count == 0: + break + assert count == 0 # 检查是否找到最后一个'}' + return s[i : j + 1] + + ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" + print(__extract_json(ss)) diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py index 424bfaa7f..7fef1862f 100644 --- a/pilot/connections/rdbms/rdbms_connect.py +++ b/pilot/connections/rdbms/rdbms_connect.py @@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect): """SQLAlchemy wrapper around a database.""" def __init__( - self, - engine, - schema: Optional[str] = None, - metadata: Optional[MetaData] = None, - ignore_tables: Optional[List[str]] = None, - include_tables: Optional[List[str]] = None, - + self, + engine, + schema: Optional[str] = None, + metadata: Optional[MetaData] = None, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, ): """Create engine from database URI.""" self._engine = engine @@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect): Todo password encryption Returns: """ - return cls.from_uri_db(cls, - CFG.LOCAL_DB_HOST, - CFG.LOCAL_DB_PORT, - CFG.LOCAL_DB_USER, - CFG.LOCAL_DB_PASSWORD, - engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) + return cls.from_uri_db( + cls, + CFG.LOCAL_DB_HOST, + CFG.LOCAL_DB_PORT, + CFG.LOCAL_DB_USER, + CFG.LOCAL_DB_PASSWORD, + engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}, + ) @classmethod - def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None, - engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: - db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str( - CFG.LOCAL_DB_PORT) + def from_uri_db( + cls, + host: str, + port: int, + user: str, + pwd: str, + db_name: str = None, + engine_args: Optional[dict] = None, + **kwargs: Any, + ) -> RDBMSDatabase: + db_url: str = ( + cls.connect_driver + + "://" + + CFG.LOCAL_DB_USER + + ":" + + CFG.LOCAL_DB_PASSWORD + + "@" + + CFG.LOCAL_DB_HOST + + ":" + + str(CFG.LOCAL_DB_PORT) + ) if cls.dialect: db_url = cls.dialect + "+" + db_url if db_name: @@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect): @classmethod def from_uri( - cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any + cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any ) -> RDBMSDatabase: """Construct a SQLAlchemy engine from URI.""" _engine_args = engine_args or {} @@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect): tbl for tbl in self._metadata.sorted_tables if tbl.name in set(all_table_names) - and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) + and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) ] tables = [] @@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect): create_table = str(CreateTable(table).compile(self._engine)) table_info = f"{create_table.rstrip()}" has_extra_info = ( - self._indexes_in_table_info or self._sample_rows_in_table_info + self._indexes_in_table_info or self._sample_rows_in_table_info ) if has_extra_info: table_info += "\n\n/*" diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 8e98ed4c9..4336d43e3 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) } ) - # Move the last user's information to the end + # Move the last user's information to the end temp_his = history[::-1] last_user_input = None for m in temp_his: @@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) text = "" for line in res.iter_lines(): if line: - json_data = line.split(b': ', 1)[1] + json_data = line.split(b": ", 1)[1] decoded_line = json_data.decode("utf-8") if decoded_line.lower() != "[DONE]".lower(): obj = json.loads(json_data) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 3b4c9e028..bd968aef1 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -121,17 +121,17 @@ class BaseOutputParser(ABC): raise ValueError("Model server error!code=" + respObj_ex["error_code"]) def __extract_json(slef, s): - i = s.index('{') + i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1:], start=i + 1): - if c == '}': + for j, c in enumerate(s[i + 1 :], start=i + 1): + if c == "}": count -= 1 - elif c == '{': + elif c == "{": count += 1 if count == 0: break - assert (count == 0) # 检查是否找到最后一个'}' - return s[i:j + 1] + assert count == 0 # 检查是否找到最后一个'}' + return s[i : j + 1] def parse_prompt_response(self, model_out_text) -> T: """ diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 497b2cd10..0120b9e86 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -134,7 +134,6 @@ class BaseChat(ABC): return payload def stream_call(self): - # TODO Retry when server connection error payload = self.__call_base() @@ -189,19 +188,19 @@ class BaseChat(ABC): ) ) -# ### MOCK -# ai_response_text = """{ -# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", -# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", -# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", -# "command": { -# "name": "histogram-executor", -# "args": { -# "title": "订单城市分布柱状图", -# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" -# } -# } -# }""" + # ### MOCK + # ai_response_text = """{ + # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", + # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", + # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", + # "command": { + # "name": "histogram-executor", + # "args": { + # "title": "订单城市分布柱状图", + # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" + # } + # } + # }""" self.current_message.add_ai_message(ai_response_text) prompt_define_response = ( diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index f91af967c..1dcb4c6ed 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat): def __list_to_prompt_str(self, list: List) -> str: return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) - def generate(self, p) -> str: return super().generate(p) diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index 44f203d1e..565d54c5e 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser): command, thoughts, speak = ( response["command"], response["thoughts"], - response["speak"] + response["speak"], ) return PluginAction(command, speak, thoughts) diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 3f21b828d..6116deecd 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat): context = context[:2000] input_values = {"context": context, "question": self.current_user_input} except NoIndexException: - raise ValueError("you have no default knowledge store, please execute python knowledge_init.py") + raise ValueError( + "you have no default knowledge store, please execute python knowledge_init.py" + ) return input_values def do_with_prompt_response(self, prompt_response): diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py index 55f525988..ac72fc637 100644 --- a/pilot/server/__init__.py +++ b/pilot/server/__init__.py @@ -5,7 +5,6 @@ import sys from dotenv import load_dotenv - if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): print("Setting random seed to 42") random.seed(42) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index beab61d4a..1e3a4dcb3 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -87,7 +87,10 @@ class ModelWorker: ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} yield json.dumps(ret).encode() + b"\0" except Exception as e: - ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0} + ret = { + "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + "error_code": 0, + } yield json.dumps(ret).encode() + b"\0" def get_embeddings(self, prompt): diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index e76865550..761a239e7 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -667,8 +667,8 @@ if __name__ == "__main__": args = parser.parse_args() logger.info(f"args: {args}") - - # init config + + # init config cfg = Config() load_native_plugins(cfg) @@ -682,7 +682,7 @@ if __name__ == "__main__": "pilot.commands.built_in.audio_text", "pilot.commands.built_in.image_gen", ] - # exclude commands + # exclude commands command_categories = [ x for x in command_categories if x not in cfg.disabled_command_categories ] diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 5f6d9526d..60046d0cd 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding): def read(self): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) - textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) + textsplitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=200, + ) return loader.load_and_split(textsplitter) @register diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 66b0963d9..87ad9d1cf 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding): # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # ) textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200 + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=200, ) return loader.load_and_split(textsplitter) diff --git a/pilot/source_embedding/ppt_embedding.py b/pilot/source_embedding/ppt_embedding.py index 869e92395..583b29ed1 100644 --- a/pilot/source_embedding/ppt_embedding.py +++ b/pilot/source_embedding/ppt_embedding.py @@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding): def read(self): """Load from ppt path.""" loader = UnstructuredPowerPointLoader(self.file_path) - textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) + textsplitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=200, + ) return loader.load_and_split(textsplitter) @register diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 84fbf1550..5e551514b 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -78,7 +78,7 @@ class DBSummaryClient: model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, ) - table_docs =knowledge_embedding_client.similar_search(query, topk) + table_docs = knowledge_embedding_client.similar_search(query, topk) ans = [d.page_content for d in table_docs] return ans @@ -147,8 +147,6 @@ class DBSummaryClient: logger.info("init db profile success...") - - def _get_llm_response(query, db_input, dbsummary): chat_param = { "temperature": 0.7, diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py index 4a578fe2c..08a01c0fc 100644 --- a/pilot/summary/mysql_db_summary.py +++ b/pilot/summary/mysql_db_summary.py @@ -43,15 +43,14 @@ CFG = Config() # "tps": 50 # } + class MysqlSummary(DBSummary): """Get mysql summary template.""" def __init__(self, name): self.name = name self.type = "MYSQL" - self.summery = ( - """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" - ) + self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" self.tables = {} self.tables_info = [] self.vector_tables_info = [] @@ -92,9 +91,12 @@ class MysqlSummary(DBSummary): self.tables[table_name] = table_summary.get_columns() self.table_columns_info.append(table_summary.get_columns()) # self.table_columns_json.append(table_summary.get_summary_json()) - table_profile = "table name:{table_name},table description:{table_comment}".format( - table_name=table_name, table_comment=self.db.get_show_create_table(table_name) + table_profile = ( + "table name:{table_name},table description:{table_comment}".format( + table_name=table_name, + table_comment=self.db.get_show_create_table(table_name), ) + ) self.table_columns_json.append(table_profile) # self.tables_info.append(table_summary.get_summery()) @@ -108,7 +110,11 @@ class MysqlSummary(DBSummary): def get_db_summery(self): return self.summery.format( - name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000 + name=self.name, + type=self.type, + tables=";".join(self.vector_tables_info), + qps=1000, + tps=1000, ) def get_table_summary(self): @@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary): self.indexes_info.append(index_summary.get_summery()) self.json_summery = self.json_summery_template.format( - name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000 + name=name, + comment=comment_map[name], + fields=self.fields_info, + indexes=self.indexes_info, + size_in_bytes=1000, + rows=1000, ) def get_summery(self): @@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary): self.bind_fields = index[1] def get_summery(self): - return self.summery_template.format(name=self.name, bind_fields=self.bind_fields) + return self.summery_template.format( + name=self.name, bind_fields=self.bind_fields + ) if __name__ == "__main__":