mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 12:18:12 +00:00
pylint: multi model for gp4all (#138)
This commit is contained in:
parent
3927c26dea
commit
ff6d3a7035
@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
|
||||
print("load_native_plugins")
|
||||
### TODO 默认拉主分支,后续拉发布版本
|
||||
branch_name = cfg.plugins_git_branch
|
||||
native_plugin_repo ="DB-GPT-Plugins"
|
||||
native_plugin_repo = "DB-GPT-Plugins"
|
||||
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||
response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
|
||||
headers={'Authorization': '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'})
|
||||
response = requests.get(
|
||||
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||
headers={"Authorization": "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"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
|
||||
files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime('%Y%m%d%H%M%S')
|
||||
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = (
|
||||
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||
)
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
@ -66,7 +66,6 @@ class Database:
|
||||
self._sample_rows_in_table_info = set()
|
||||
self._indexes_in_table_info = indexes_in_table_info
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
@ -399,7 +398,6 @@ class Database:
|
||||
ans = cursor.fetchall()
|
||||
return ans[0][1]
|
||||
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
|
@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
|
||||
current_directory = os.getcwd()
|
||||
|
||||
|
@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ClickHouseConnector(RDBMSDatabase):
|
||||
"""ClickHouseConnector"""
|
||||
|
||||
@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls) -> RDBMSDatabase:
|
||||
"""
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
return cls.from_uri_db(
|
||||
cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(cls, db_path: str,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
def from_uri_db(
|
||||
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + db_path
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
Args:
|
||||
@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls) -> RDBMSDatabase:
|
||||
"""
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
return cls.from_uri_db(
|
||||
cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(cls, db_path: str,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
def from_uri_db(
|
||||
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + db_path
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
@ -5,9 +5,6 @@ from typing import Optional, Any
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase):
|
||||
dialect: str = "mssql"
|
||||
driver: str = "pyodbc"
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
|
||||
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
|
||||
|
@ -5,9 +5,6 @@ from typing import Optional, Any
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MySQLConnect(RDBMSDatabase):
|
||||
"""Connect MySQL Database fetch MetaData
|
||||
Args:
|
||||
@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
|
||||
driver: str = "pymysql"
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
|
@ -2,8 +2,10 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class OracleConnector(RDBMSDatabase):
|
||||
"""OracleConnector"""
|
||||
|
||||
type: str = "ORACLE"
|
||||
|
||||
driver: str = "oracle"
|
||||
|
@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class PostgresConnector(RDBMSDatabase):
|
||||
"""PostgresConnector is a class which Connector"""
|
||||
|
||||
|
@ -57,18 +57,19 @@ CFG = Config()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def __extract_json(s):
|
||||
i = s.index('{')
|
||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
for j, c in enumerate(s[i + 1:], start=i + 1):
|
||||
if c == '}':
|
||||
count -= 1
|
||||
elif c == '{':
|
||||
count += 1
|
||||
if count == 0:
|
||||
break
|
||||
assert (count == 0) # 检查是否找到最后一个'}'
|
||||
return s[i:j + 1]
|
||||
|
||||
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
print(__extract_json(ss))
|
||||
def __extract_json(s):
|
||||
i = s.index("{")
|
||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||
if c == "}":
|
||||
count -= 1
|
||||
elif c == "{":
|
||||
count += 1
|
||||
if count == 0:
|
||||
break
|
||||
assert count == 0 # 检查是否找到最后一个'}'
|
||||
return s[i : j + 1]
|
||||
|
||||
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
print(__extract_json(ss))
|
||||
|
@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_HOST,
|
||||
CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER,
|
||||
CFG.LOCAL_DB_PASSWORD,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
return cls.from_uri_db(
|
||||
cls,
|
||||
CFG.LOCAL_DB_HOST,
|
||||
CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER,
|
||||
CFG.LOCAL_DB_PASSWORD,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
|
||||
CFG.LOCAL_DB_PORT)
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str = None,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
cls.connect_driver
|
||||
+ "://"
|
||||
+ CFG.LOCAL_DB_USER
|
||||
+ ":"
|
||||
+ CFG.LOCAL_DB_PASSWORD
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
)
|
||||
if cls.dialect:
|
||||
db_url = cls.dialect + "+" + db_url
|
||||
if db_name:
|
||||
@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
]
|
||||
|
||||
tables = []
|
||||
@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
if has_extra_info:
|
||||
table_info += "\n\n/*"
|
||||
|
@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
}
|
||||
)
|
||||
|
||||
# Move the last user's information to the end
|
||||
# Move the last user's information to the end
|
||||
temp_his = history[::-1]
|
||||
last_user_input = None
|
||||
for m in temp_his:
|
||||
@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
text = ""
|
||||
for line in res.iter_lines():
|
||||
if line:
|
||||
json_data = line.split(b': ', 1)[1]
|
||||
json_data = line.split(b": ", 1)[1]
|
||||
decoded_line = json_data.decode("utf-8")
|
||||
if decoded_line.lower() != "[DONE]".lower():
|
||||
obj = json.loads(json_data)
|
||||
|
@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
|
||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||
|
||||
def __extract_json(slef, s):
|
||||
i = s.index('{')
|
||||
i = s.index("{")
|
||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
for j, c in enumerate(s[i + 1:], start=i + 1):
|
||||
if c == '}':
|
||||
for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||
if c == "}":
|
||||
count -= 1
|
||||
elif c == '{':
|
||||
elif c == "{":
|
||||
count += 1
|
||||
if count == 0:
|
||||
break
|
||||
assert (count == 0) # 检查是否找到最后一个'}'
|
||||
return s[i:j + 1]
|
||||
assert count == 0 # 检查是否找到最后一个'}'
|
||||
return s[i : j + 1]
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
"""
|
||||
|
@ -134,7 +134,6 @@ class BaseChat(ABC):
|
||||
return payload
|
||||
|
||||
def stream_call(self):
|
||||
|
||||
# TODO Retry when server connection error
|
||||
payload = self.__call_base()
|
||||
|
||||
@ -189,19 +188,19 @@ class BaseChat(ABC):
|
||||
)
|
||||
)
|
||||
|
||||
# ### MOCK
|
||||
# ai_response_text = """{
|
||||
# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
|
||||
# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
|
||||
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
|
||||
# "command": {
|
||||
# "name": "histogram-executor",
|
||||
# "args": {
|
||||
# "title": "订单城市分布柱状图",
|
||||
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
|
||||
# }
|
||||
# }
|
||||
# }"""
|
||||
# ### MOCK
|
||||
# ai_response_text = """{
|
||||
# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
|
||||
# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
|
||||
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
|
||||
# "command": {
|
||||
# "name": "histogram-executor",
|
||||
# "args": {
|
||||
# "title": "订单城市分布柱状图",
|
||||
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
|
||||
# }
|
||||
# }
|
||||
# }"""
|
||||
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
prompt_define_response = (
|
||||
|
@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
||||
|
||||
def generate(self, p) -> str:
|
||||
return super().generate(p)
|
||||
|
||||
|
@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
|
||||
command, thoughts, speak = (
|
||||
response["command"],
|
||||
response["thoughts"],
|
||||
response["speak"]
|
||||
response["speak"],
|
||||
)
|
||||
return PluginAction(command, speak, thoughts)
|
||||
|
||||
|
@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
|
||||
context = context[:2000]
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
except NoIndexException:
|
||||
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
|
||||
raise ValueError(
|
||||
"you have no default knowledge store, please execute python knowledge_init.py"
|
||||
)
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
|
@ -5,7 +5,6 @@ import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
|
||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
print("Setting random seed to 42")
|
||||
random.seed(42)
|
||||
|
@ -87,7 +87,10 @@ class ModelWorker:
|
||||
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except Exception as e:
|
||||
ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0}
|
||||
ret = {
|
||||
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
"error_code": 0,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
def get_embeddings(self, prompt):
|
||||
|
@ -667,8 +667,8 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
# init config
|
||||
|
||||
# init config
|
||||
cfg = Config()
|
||||
|
||||
load_native_plugins(cfg)
|
||||
@ -682,7 +682,7 @@ if __name__ == "__main__":
|
||||
"pilot.commands.built_in.audio_text",
|
||||
"pilot.commands.built_in.image_gen",
|
||||
]
|
||||
# exclude commands
|
||||
# exclude commands
|
||||
command_categories = [
|
||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||
]
|
||||
|
@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
|
||||
def read(self):
|
||||
"""Load from markdown path."""
|
||||
loader = EncodeTextLoader(self.file_path)
|
||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||
textsplitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||
chunk_overlap=200,
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
@register
|
||||
|
@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
|
||||
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
# )
|
||||
textsplitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||
chunk_overlap=200,
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
|
@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
|
||||
def read(self):
|
||||
"""Load from ppt path."""
|
||||
loader = UnstructuredPowerPointLoader(self.file_path)
|
||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||
textsplitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||
chunk_overlap=200,
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
@register
|
||||
|
@ -78,7 +78,7 @@ class DBSummaryClient:
|
||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
table_docs =knowledge_embedding_client.similar_search(query, topk)
|
||||
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
||||
ans = [d.page_content for d in table_docs]
|
||||
return ans
|
||||
|
||||
@ -147,8 +147,6 @@ class DBSummaryClient:
|
||||
logger.info("init db profile success...")
|
||||
|
||||
|
||||
|
||||
|
||||
def _get_llm_response(query, db_input, dbsummary):
|
||||
chat_param = {
|
||||
"temperature": 0.7,
|
||||
|
@ -43,15 +43,14 @@ CFG = Config()
|
||||
# "tps": 50
|
||||
# }
|
||||
|
||||
|
||||
class MysqlSummary(DBSummary):
|
||||
"""Get mysql summary template."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.type = "MYSQL"
|
||||
self.summery = (
|
||||
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||
)
|
||||
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||
self.tables = {}
|
||||
self.tables_info = []
|
||||
self.vector_tables_info = []
|
||||
@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
|
||||
self.tables[table_name] = table_summary.get_columns()
|
||||
self.table_columns_info.append(table_summary.get_columns())
|
||||
# self.table_columns_json.append(table_summary.get_summary_json())
|
||||
table_profile = "table name:{table_name},table description:{table_comment}".format(
|
||||
table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
|
||||
table_profile = (
|
||||
"table name:{table_name},table description:{table_comment}".format(
|
||||
table_name=table_name,
|
||||
table_comment=self.db.get_show_create_table(table_name),
|
||||
)
|
||||
)
|
||||
self.table_columns_json.append(table_profile)
|
||||
# self.tables_info.append(table_summary.get_summery())
|
||||
|
||||
@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
|
||||
|
||||
def get_db_summery(self):
|
||||
return self.summery.format(
|
||||
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
|
||||
name=self.name,
|
||||
type=self.type,
|
||||
tables=";".join(self.vector_tables_info),
|
||||
qps=1000,
|
||||
tps=1000,
|
||||
)
|
||||
|
||||
def get_table_summary(self):
|
||||
@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
|
||||
self.indexes_info.append(index_summary.get_summery())
|
||||
|
||||
self.json_summery = self.json_summery_template.format(
|
||||
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
|
||||
name=name,
|
||||
comment=comment_map[name],
|
||||
fields=self.fields_info,
|
||||
indexes=self.indexes_info,
|
||||
size_in_bytes=1000,
|
||||
rows=1000,
|
||||
)
|
||||
|
||||
def get_summery(self):
|
||||
@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
|
||||
self.bind_fields = index[1]
|
||||
|
||||
def get_summery(self):
|
||||
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
|
||||
return self.summery_template.format(
|
||||
name=self.name, bind_fields=self.bind_fields
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user