pylint: multi model for gp4all (#138)

This commit is contained in:
csunny 2023-06-14 10:17:53 +08:00
parent 3927c26dea
commit ff6d3a7035
25 changed files with 154 additions and 111 deletions

View File

@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
print("load_native_plugins") print("load_native_plugins")
### TODO 默认拉主分支,后续拉发布版本 ### TODO 默认拉主分支,后续拉发布版本
branch_name = cfg.plugins_git_branch branch_name = cfg.plugins_git_branch
native_plugin_repo ="DB-GPT-Plugins" native_plugin_repo = "DB-GPT-Plugins"
url = "https://github.com/csunny/{repo}/archive/{branch}.zip" url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name), response = requests.get(
headers={'Authorization': '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'}) url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200: if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR) plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
for file in files: for file in files:
os.remove(file) os.remove(file)
now = datetime.datetime.now() now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S') time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" file_name = (
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
)
print(file_name) print(file_name)
with open(file_name, "wb") as f: with open(file_name, "wb") as f:
f.write(response.content) f.write(response.content)

View File

@ -66,7 +66,6 @@ class Database:
self._sample_rows_in_table_info = set() self._sample_rows_in_table_info = set()
self._indexes_in_table_info = indexes_in_table_info self._indexes_in_table_info = indexes_in_table_info
@classmethod @classmethod
def from_uri( def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
@ -399,7 +398,6 @@ class Database:
ans = cursor.fetchall() ans = cursor.fetchall()
return ans[0][1] return ans[0][1]
def get_fields(self, table_name): def get_fields(self, table_name):
"""Get column fields about specified table.""" """Get column fields about specified table."""
session = self._db_sessions() session = self._db_sessions()

View File

@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data") DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts") FONT_DIR = os.path.join(PILOT_PATH, "fonts")
current_directory = os.getcwd() current_directory = os.getcwd()

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
class ClickHouseConnector(RDBMSDatabase): class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector""" """ClickHouseConnector"""
@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod @classmethod
def from_config(cls) -> RDBMSDatabase: def from_config(cls) -> RDBMSDatabase:
""" """
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
CFG.LOCAL_DB_PATH, cls,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, db_path: str, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
class DuckDbConnect(RDBMSDatabase): class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData """Connect Duckdb Database fetch MetaData
Args: Args:
@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod @classmethod
def from_config(cls) -> RDBMSDatabase: def from_config(cls) -> RDBMSDatabase:
""" """
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
CFG.LOCAL_DB_PATH, cls,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, db_path: str, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase): class MSSQLConnect(RDBMSDatabase):
"""Connect MSSQL Database fetch MetaData """Connect MSSQL Database fetch MetaData
Args: Args:
@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase):
dialect: str = "mssql" dialect: str = "mssql"
driver: str = "pyodbc" driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"] default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MySQLConnect(RDBMSDatabase): class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData """Connect MySQL Database fetch MetaData
Args: Args:
@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
driver: str = "pymysql" driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -2,8 +2,10 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase): class OracleConnector(RDBMSDatabase):
"""OracleConnector""" """OracleConnector"""
type: str = "ORACLE" type: str = "ORACLE"
driver: str = "oracle" driver: str = "oracle"

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase): class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector""" """PostgresConnector is a class which Connector"""

View File

@ -57,18 +57,19 @@ CFG = Config()
if __name__ == "__main__": if __name__ == "__main__":
def __extract_json(s):
i = s.index('{')
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == '}':
count -= 1
elif c == '{':
count += 1
if count == 0:
break
assert (count == 0) # 检查是否找到最后一个'}'
return s[i:j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" def __extract_json(s):
print(__extract_json(ss)) i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
print(__extract_json(ss))

View File

@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database.""" """SQLAlchemy wrapper around a database."""
def __init__( def __init__(
self, self,
engine, engine,
schema: Optional[str] = None, schema: Optional[str] = None,
metadata: Optional[MetaData] = None, metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None, ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
): ):
"""Create engine from database URI.""" """Create engine from database URI."""
self._engine = engine self._engine = engine
@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
CFG.LOCAL_DB_HOST, cls,
CFG.LOCAL_DB_PORT, CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_PASSWORD, CFG.LOCAL_DB_USER,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls,
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str( host: str,
CFG.LOCAL_DB_PORT) port: int,
user: str,
pwd: str,
db_name: str = None,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.connect_driver
+ "://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
if cls.dialect: if cls.dialect:
db_url = cls.dialect + "+" + db_url db_url = cls.dialect + "+" + db_url
if db_name: if db_name:
@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
@classmethod @classmethod
def from_uri( def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase: ) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI.""" """Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {} _engine_args = engine_args or {}
@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
tbl tbl
for tbl in self._metadata.sorted_tables for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names) if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
] ]
tables = [] tables = []
@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine)) create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}" table_info = f"{create_table.rstrip()}"
has_extra_info = ( has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info self._indexes_in_table_info or self._sample_rows_in_table_info
) )
if has_extra_info: if has_extra_info:
table_info += "\n\n/*" table_info += "\n\n/*"

View File

@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
} }
) )
# Move the last user's information to the end # Move the last user's information to the end
temp_his = history[::-1] temp_his = history[::-1]
last_user_input = None last_user_input = None
for m in temp_his: for m in temp_his:
@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
text = "" text = ""
for line in res.iter_lines(): for line in res.iter_lines():
if line: if line:
json_data = line.split(b': ', 1)[1] json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8") decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower(): if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data) obj = json.loads(json_data)

View File

@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"]) raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(slef, s): def __extract_json(slef, s):
i = s.index('{') i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1:], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == '}': if c == "}":
count -= 1 count -= 1
elif c == '{': elif c == "{":
count += 1 count += 1
if count == 0: if count == 0:
break break
assert (count == 0) # 检查是否找到最后一个'}' assert count == 0 # 检查是否找到最后一个'}'
return s[i:j + 1] return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
""" """

View File

@ -134,7 +134,6 @@ class BaseChat(ABC):
return payload return payload
def stream_call(self): def stream_call(self):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = self.__call_base() payload = self.__call_base()
@ -189,19 +188,19 @@ class BaseChat(ABC):
) )
) )
# ### MOCK # ### MOCK
# ai_response_text = """{ # ai_response_text = """{
# "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。", # "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
# "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。", # "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
# "command": { # "command": {
# "name": "histogram-executor", # "name": "histogram-executor",
# "args": { # "args": {
# "title": "订单城市分布柱状图", # "title": "订单城市分布柱状图",
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
# } # }
# } # }
# }""" # }"""
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text)
prompt_define_response = ( prompt_define_response = (

View File

@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
def __list_to_prompt_str(self, list: List) -> str: def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
def generate(self, p) -> str: def generate(self, p) -> str:
return super().generate(p) return super().generate(p)

View File

@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
command, thoughts, speak = ( command, thoughts, speak = (
response["command"], response["command"],
response["thoughts"], response["thoughts"],
response["speak"] response["speak"],
) )
return PluginAction(command, speak, thoughts) return PluginAction(command, speak, thoughts)

View File

@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
context = context[:2000] context = context[:2000]
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
except NoIndexException: except NoIndexException:
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py") raise ValueError(
"you have no default knowledge store, please execute python knowledge_init.py"
)
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_with_prompt_response(self, prompt_response):

View File

@ -5,7 +5,6 @@ import sys
from dotenv import load_dotenv from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42") print("Setting random seed to 42")
random.seed(42) random.seed(42)

View File

@ -87,7 +87,10 @@ class ModelWorker:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0" yield json.dumps(ret).encode() + b"\0"
except Exception as e: except Exception as e:
ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0} ret = {
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0" yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt): def get_embeddings(self, prompt):

View File

@ -667,8 +667,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") logger.info(f"args: {args}")
# init config # init config
cfg = Config() cfg = Config()
load_native_plugins(cfg) load_native_plugins(cfg)
@ -682,7 +682,7 @@ if __name__ == "__main__":
"pilot.commands.built_in.audio_text", "pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen", "pilot.commands.built_in.image_gen",
] ]
# exclude commands # exclude commands
command_categories = [ command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories x for x in command_categories if x not in cfg.disabled_command_categories
] ]

View File

@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from markdown path.""" """Load from markdown path."""
loader = EncodeTextLoader(self.file_path) loader = EncodeTextLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)
@register @register

View File

@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# ) # )
textsplitter = SpacyTextSplitter( textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200 pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
) )
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)

View File

@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from ppt path.""" """Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path) loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)
@register @register

View File

@ -78,7 +78,7 @@ class DBSummaryClient:
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
table_docs =knowledge_embedding_client.similar_search(query, topk) table_docs = knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs] ans = [d.page_content for d in table_docs]
return ans return ans
@ -147,8 +147,6 @@ class DBSummaryClient:
logger.info("init db profile success...") logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary): def _get_llm_response(query, db_input, dbsummary):
chat_param = { chat_param = {
"temperature": 0.7, "temperature": 0.7,

View File

@ -43,15 +43,14 @@ CFG = Config()
# "tps": 50 # "tps": 50
# } # }
class MysqlSummary(DBSummary): class MysqlSummary(DBSummary):
"""Get mysql summary template.""" """Get mysql summary template."""
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.type = "MYSQL" self.type = "MYSQL"
self.summery = ( self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
)
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
self.vector_tables_info = [] self.vector_tables_info = []
@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
self.tables[table_name] = table_summary.get_columns() self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns()) self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json()) # self.table_columns_json.append(table_summary.get_summary_json())
table_profile = "table name:{table_name},table description:{table_comment}".format( table_profile = (
table_name=table_name, table_comment=self.db.get_show_create_table(table_name) "table name:{table_name},table description:{table_comment}".format(
table_name=table_name,
table_comment=self.db.get_show_create_table(table_name),
) )
)
self.table_columns_json.append(table_profile) self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery()) # self.tables_info.append(table_summary.get_summery())
@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
def get_db_summery(self): def get_db_summery(self):
return self.summery.format( return self.summery.format(
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000 name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
qps=1000,
tps=1000,
) )
def get_table_summary(self): def get_table_summary(self):
@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
self.indexes_info.append(index_summary.get_summery()) self.indexes_info.append(index_summary.get_summery())
self.json_summery = self.json_summery_template.format( self.json_summery = self.json_summery_template.format(
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000 name=name,
comment=comment_map[name],
fields=self.fields_info,
indexes=self.indexes_info,
size_in_bytes=1000,
rows=1000,
) )
def get_summery(self): def get_summery(self):
@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
self.bind_fields = index[1] self.bind_fields = index[1]
def get_summery(self): def get_summery(self):
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields) return self.summery_template.format(
name=self.name, bind_fields=self.bind_fields
)
if __name__ == "__main__": if __name__ == "__main__":