mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 20:52:33 +00:00
pylint: multi model for gp4all (#138)
This commit is contained in:
parent
3927c26dea
commit
ff6d3a7035
@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
|
|||||||
print("load_native_plugins")
|
print("load_native_plugins")
|
||||||
### TODO 默认拉主分支,后续拉发布版本
|
### TODO 默认拉主分支,后续拉发布版本
|
||||||
branch_name = cfg.plugins_git_branch
|
branch_name = cfg.plugins_git_branch
|
||||||
native_plugin_repo ="DB-GPT-Plugins"
|
native_plugin_repo = "DB-GPT-Plugins"
|
||||||
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||||
response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
|
response = requests.get(
|
||||||
headers={'Authorization': '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'})
|
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||||
|
headers={"Authorization": "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"},
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
plugins_path_path = Path(PLUGINS_DIR)
|
plugins_path_path = Path(PLUGINS_DIR)
|
||||||
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
|
files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
|
||||||
for file in files:
|
for file in files:
|
||||||
os.remove(file)
|
os.remove(file)
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
time_str = now.strftime('%Y%m%d%H%M%S')
|
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||||
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
file_name = (
|
||||||
|
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||||
|
)
|
||||||
print(file_name)
|
print(file_name)
|
||||||
with open(file_name, "wb") as f:
|
with open(file_name, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
@ -66,7 +66,6 @@ class Database:
|
|||||||
self._sample_rows_in_table_info = set()
|
self._sample_rows_in_table_info = set()
|
||||||
self._indexes_in_table_info = indexes_in_table_info
|
self._indexes_in_table_info = indexes_in_table_info
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri(
|
def from_uri(
|
||||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
@ -399,7 +398,6 @@ class Database:
|
|||||||
ans = cursor.fetchall()
|
ans = cursor.fetchall()
|
||||||
return ans[0][1]
|
return ans[0][1]
|
||||||
|
|
||||||
|
|
||||||
def get_fields(self, table_name):
|
def get_fields(self, table_name):
|
||||||
"""Get column fields about specified table."""
|
"""Get column fields about specified table."""
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
|
@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
|
|||||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||||
|
|
||||||
current_directory = os.getcwd()
|
current_directory = os.getcwd()
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class ClickHouseConnector(RDBMSDatabase):
|
class ClickHouseConnector(RDBMSDatabase):
|
||||||
"""ClickHouseConnector"""
|
"""ClickHouseConnector"""
|
||||||
|
|
||||||
@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
|
|||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls) -> RDBMSDatabase:
|
def from_config(cls) -> RDBMSDatabase:
|
||||||
"""
|
"""
|
||||||
Todo password encryption
|
Todo password encryption
|
||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
return cls.from_uri_db(cls,
|
return cls.from_uri_db(
|
||||||
CFG.LOCAL_DB_PATH,
|
cls,
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
CFG.LOCAL_DB_PATH,
|
||||||
|
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri_db(cls, db_path: str,
|
def from_uri_db(
|
||||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
|
) -> RDBMSDatabase:
|
||||||
db_url: str = cls.connect_driver + "://" + db_path
|
db_url: str = cls.connect_driver + "://" + db_path
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class DuckDbConnect(RDBMSDatabase):
|
class DuckDbConnect(RDBMSDatabase):
|
||||||
"""Connect Duckdb Database fetch MetaData
|
"""Connect Duckdb Database fetch MetaData
|
||||||
Args:
|
Args:
|
||||||
@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
|
|||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls) -> RDBMSDatabase:
|
def from_config(cls) -> RDBMSDatabase:
|
||||||
"""
|
"""
|
||||||
Todo password encryption
|
Todo password encryption
|
||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
return cls.from_uri_db(cls,
|
return cls.from_uri_db(
|
||||||
CFG.LOCAL_DB_PATH,
|
cls,
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
CFG.LOCAL_DB_PATH,
|
||||||
|
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri_db(cls, db_path: str,
|
def from_uri_db(
|
||||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
|
) -> RDBMSDatabase:
|
||||||
db_url: str = cls.connect_driver + "://" + db_path
|
db_url: str = cls.connect_driver + "://" + db_path
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
@ -5,9 +5,6 @@ from typing import Optional, Any
|
|||||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MSSQLConnect(RDBMSDatabase):
|
class MSSQLConnect(RDBMSDatabase):
|
||||||
"""Connect MSSQL Database fetch MetaData
|
"""Connect MSSQL Database fetch MetaData
|
||||||
Args:
|
Args:
|
||||||
@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase):
|
|||||||
dialect: str = "mssql"
|
dialect: str = "mssql"
|
||||||
driver: str = "pyodbc"
|
driver: str = "pyodbc"
|
||||||
|
|
||||||
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
|
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,9 +5,6 @@ from typing import Optional, Any
|
|||||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MySQLConnect(RDBMSDatabase):
|
class MySQLConnect(RDBMSDatabase):
|
||||||
"""Connect MySQL Database fetch MetaData
|
"""Connect MySQL Database fetch MetaData
|
||||||
Args:
|
Args:
|
||||||
@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
|
|||||||
driver: str = "pymysql"
|
driver: str = "pymysql"
|
||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,8 +2,10 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
|
||||||
|
|
||||||
class OracleConnector(RDBMSDatabase):
|
class OracleConnector(RDBMSDatabase):
|
||||||
"""OracleConnector"""
|
"""OracleConnector"""
|
||||||
|
|
||||||
type: str = "ORACLE"
|
type: str = "ORACLE"
|
||||||
|
|
||||||
driver: str = "oracle"
|
driver: str = "oracle"
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
|
||||||
|
|
||||||
class PostgresConnector(RDBMSDatabase):
|
class PostgresConnector(RDBMSDatabase):
|
||||||
"""PostgresConnector is a class which Connector"""
|
"""PostgresConnector is a class which Connector"""
|
||||||
|
|
||||||
|
@ -57,18 +57,19 @@ CFG = Config()
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
def __extract_json(s):
|
|
||||||
i = s.index('{')
|
|
||||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
|
||||||
for j, c in enumerate(s[i + 1:], start=i + 1):
|
|
||||||
if c == '}':
|
|
||||||
count -= 1
|
|
||||||
elif c == '{':
|
|
||||||
count += 1
|
|
||||||
if count == 0:
|
|
||||||
break
|
|
||||||
assert (count == 0) # 检查是否找到最后一个'}'
|
|
||||||
return s[i:j + 1]
|
|
||||||
|
|
||||||
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
def __extract_json(s):
|
||||||
print(__extract_json(ss))
|
i = s.index("{")
|
||||||
|
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||||
|
for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||||
|
if c == "}":
|
||||||
|
count -= 1
|
||||||
|
elif c == "{":
|
||||||
|
count += 1
|
||||||
|
if count == 0:
|
||||||
|
break
|
||||||
|
assert count == 0 # 检查是否找到最后一个'}'
|
||||||
|
return s[i : j + 1]
|
||||||
|
|
||||||
|
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||||
|
print(__extract_json(ss))
|
||||||
|
@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
engine,
|
engine,
|
||||||
schema: Optional[str] = None,
|
schema: Optional[str] = None,
|
||||||
metadata: Optional[MetaData] = None,
|
metadata: Optional[MetaData] = None,
|
||||||
ignore_tables: Optional[List[str]] = None,
|
ignore_tables: Optional[List[str]] = None,
|
||||||
include_tables: Optional[List[str]] = None,
|
include_tables: Optional[List[str]] = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
"""Create engine from database URI."""
|
"""Create engine from database URI."""
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
Todo password encryption
|
Todo password encryption
|
||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
return cls.from_uri_db(cls,
|
return cls.from_uri_db(
|
||||||
CFG.LOCAL_DB_HOST,
|
cls,
|
||||||
CFG.LOCAL_DB_PORT,
|
CFG.LOCAL_DB_HOST,
|
||||||
CFG.LOCAL_DB_USER,
|
CFG.LOCAL_DB_PORT,
|
||||||
CFG.LOCAL_DB_PASSWORD,
|
CFG.LOCAL_DB_USER,
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
CFG.LOCAL_DB_PASSWORD,
|
||||||
|
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
|
def from_uri_db(
|
||||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
cls,
|
||||||
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
|
host: str,
|
||||||
CFG.LOCAL_DB_PORT)
|
port: int,
|
||||||
|
user: str,
|
||||||
|
pwd: str,
|
||||||
|
db_name: str = None,
|
||||||
|
engine_args: Optional[dict] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RDBMSDatabase:
|
||||||
|
db_url: str = (
|
||||||
|
cls.connect_driver
|
||||||
|
+ "://"
|
||||||
|
+ CFG.LOCAL_DB_USER
|
||||||
|
+ ":"
|
||||||
|
+ CFG.LOCAL_DB_PASSWORD
|
||||||
|
+ "@"
|
||||||
|
+ CFG.LOCAL_DB_HOST
|
||||||
|
+ ":"
|
||||||
|
+ str(CFG.LOCAL_DB_PORT)
|
||||||
|
)
|
||||||
if cls.dialect:
|
if cls.dialect:
|
||||||
db_url = cls.dialect + "+" + db_url
|
db_url = cls.dialect + "+" + db_url
|
||||||
if db_name:
|
if db_name:
|
||||||
@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri(
|
def from_uri(
|
||||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
"""Construct a SQLAlchemy engine from URI."""
|
"""Construct a SQLAlchemy engine from URI."""
|
||||||
_engine_args = engine_args or {}
|
_engine_args = engine_args or {}
|
||||||
@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
tbl
|
tbl
|
||||||
for tbl in self._metadata.sorted_tables
|
for tbl in self._metadata.sorted_tables
|
||||||
if tbl.name in set(all_table_names)
|
if tbl.name in set(all_table_names)
|
||||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||||
]
|
]
|
||||||
|
|
||||||
tables = []
|
tables = []
|
||||||
@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
create_table = str(CreateTable(table).compile(self._engine))
|
create_table = str(CreateTable(table).compile(self._engine))
|
||||||
table_info = f"{create_table.rstrip()}"
|
table_info = f"{create_table.rstrip()}"
|
||||||
has_extra_info = (
|
has_extra_info = (
|
||||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||||
)
|
)
|
||||||
if has_extra_info:
|
if has_extra_info:
|
||||||
table_info += "\n\n/*"
|
table_info += "\n\n/*"
|
||||||
|
@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move the last user's information to the end
|
# Move the last user's information to the end
|
||||||
temp_his = history[::-1]
|
temp_his = history[::-1]
|
||||||
last_user_input = None
|
last_user_input = None
|
||||||
for m in temp_his:
|
for m in temp_his:
|
||||||
@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
|||||||
text = ""
|
text = ""
|
||||||
for line in res.iter_lines():
|
for line in res.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
json_data = line.split(b': ', 1)[1]
|
json_data = line.split(b": ", 1)[1]
|
||||||
decoded_line = json_data.decode("utf-8")
|
decoded_line = json_data.decode("utf-8")
|
||||||
if decoded_line.lower() != "[DONE]".lower():
|
if decoded_line.lower() != "[DONE]".lower():
|
||||||
obj = json.loads(json_data)
|
obj = json.loads(json_data)
|
||||||
|
@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
|
|||||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||||
|
|
||||||
def __extract_json(slef, s):
|
def __extract_json(slef, s):
|
||||||
i = s.index('{')
|
i = s.index("{")
|
||||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||||
for j, c in enumerate(s[i + 1:], start=i + 1):
|
for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||||
if c == '}':
|
if c == "}":
|
||||||
count -= 1
|
count -= 1
|
||||||
elif c == '{':
|
elif c == "{":
|
||||||
count += 1
|
count += 1
|
||||||
if count == 0:
|
if count == 0:
|
||||||
break
|
break
|
||||||
assert (count == 0) # 检查是否找到最后一个'}'
|
assert count == 0 # 检查是否找到最后一个'}'
|
||||||
return s[i:j + 1]
|
return s[i : j + 1]
|
||||||
|
|
||||||
def parse_prompt_response(self, model_out_text) -> T:
|
def parse_prompt_response(self, model_out_text) -> T:
|
||||||
"""
|
"""
|
||||||
|
@ -134,7 +134,6 @@ class BaseChat(ABC):
|
|||||||
return payload
|
return payload
|
||||||
|
|
||||||
def stream_call(self):
|
def stream_call(self):
|
||||||
|
|
||||||
# TODO Retry when server connection error
|
# TODO Retry when server connection error
|
||||||
payload = self.__call_base()
|
payload = self.__call_base()
|
||||||
|
|
||||||
@ -189,19 +188,19 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# ### MOCK
|
# ### MOCK
|
||||||
# ai_response_text = """{
|
# ai_response_text = """{
|
||||||
# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
|
# "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
|
||||||
# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
|
# "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
|
||||||
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
|
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
|
||||||
# "command": {
|
# "command": {
|
||||||
# "name": "histogram-executor",
|
# "name": "histogram-executor",
|
||||||
# "args": {
|
# "args": {
|
||||||
# "title": "订单城市分布柱状图",
|
# "title": "订单城市分布柱状图",
|
||||||
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
|
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
# }"""
|
# }"""
|
||||||
|
|
||||||
self.current_message.add_ai_message(ai_response_text)
|
self.current_message.add_ai_message(ai_response_text)
|
||||||
prompt_define_response = (
|
prompt_define_response = (
|
||||||
|
@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
|
|||||||
def __list_to_prompt_str(self, list: List) -> str:
|
def __list_to_prompt_str(self, list: List) -> str:
|
||||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||||
|
|
||||||
|
|
||||||
def generate(self, p) -> str:
|
def generate(self, p) -> str:
|
||||||
return super().generate(p)
|
return super().generate(p)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
|
|||||||
command, thoughts, speak = (
|
command, thoughts, speak = (
|
||||||
response["command"],
|
response["command"],
|
||||||
response["thoughts"],
|
response["thoughts"],
|
||||||
response["speak"]
|
response["speak"],
|
||||||
)
|
)
|
||||||
return PluginAction(command, speak, thoughts)
|
return PluginAction(command, speak, thoughts)
|
||||||
|
|
||||||
|
@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
|
|||||||
context = context[:2000]
|
context = context[:2000]
|
||||||
input_values = {"context": context, "question": self.current_user_input}
|
input_values = {"context": context, "question": self.current_user_input}
|
||||||
except NoIndexException:
|
except NoIndexException:
|
||||||
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
|
raise ValueError(
|
||||||
|
"you have no default knowledge store, please execute python knowledge_init.py"
|
||||||
|
)
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_with_prompt_response(self, prompt_response):
|
||||||
|
@ -5,7 +5,6 @@ import sys
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||||
print("Setting random seed to 42")
|
print("Setting random seed to 42")
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
|
@ -87,7 +87,10 @@ class ModelWorker:
|
|||||||
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
|
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
|
||||||
yield json.dumps(ret).encode() + b"\0"
|
yield json.dumps(ret).encode() + b"\0"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0}
|
ret = {
|
||||||
|
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||||
|
"error_code": 0,
|
||||||
|
}
|
||||||
yield json.dumps(ret).encode() + b"\0"
|
yield json.dumps(ret).encode() + b"\0"
|
||||||
|
|
||||||
def get_embeddings(self, prompt):
|
def get_embeddings(self, prompt):
|
||||||
|
@ -667,8 +667,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info(f"args: {args}")
|
logger.info(f"args: {args}")
|
||||||
|
|
||||||
# init config
|
# init config
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
load_native_plugins(cfg)
|
load_native_plugins(cfg)
|
||||||
@ -682,7 +682,7 @@ if __name__ == "__main__":
|
|||||||
"pilot.commands.built_in.audio_text",
|
"pilot.commands.built_in.audio_text",
|
||||||
"pilot.commands.built_in.image_gen",
|
"pilot.commands.built_in.image_gen",
|
||||||
]
|
]
|
||||||
# exclude commands
|
# exclude commands
|
||||||
command_categories = [
|
command_categories = [
|
||||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||||
]
|
]
|
||||||
|
@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
|
|||||||
def read(self):
|
def read(self):
|
||||||
"""Load from markdown path."""
|
"""Load from markdown path."""
|
||||||
loader = EncodeTextLoader(self.file_path)
|
loader = EncodeTextLoader(self.file_path)
|
||||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
textsplitter = SpacyTextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||||
|
chunk_overlap=200,
|
||||||
|
)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
@register
|
@register
|
||||||
|
@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
|
|||||||
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||||
# )
|
# )
|
||||||
textsplitter = SpacyTextSplitter(
|
textsplitter = SpacyTextSplitter(
|
||||||
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||||
|
chunk_overlap=200,
|
||||||
)
|
)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
|
@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
|
|||||||
def read(self):
|
def read(self):
|
||||||
"""Load from ppt path."""
|
"""Load from ppt path."""
|
||||||
loader = UnstructuredPowerPointLoader(self.file_path)
|
loader = UnstructuredPowerPointLoader(self.file_path)
|
||||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
textsplitter = SpacyTextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||||
|
chunk_overlap=200,
|
||||||
|
)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
@register
|
@register
|
||||||
|
@ -78,7 +78,7 @@ class DBSummaryClient:
|
|||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
table_docs =knowledge_embedding_client.similar_search(query, topk)
|
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
||||||
ans = [d.page_content for d in table_docs]
|
ans = [d.page_content for d in table_docs]
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
@ -147,8 +147,6 @@ class DBSummaryClient:
|
|||||||
logger.info("init db profile success...")
|
logger.info("init db profile success...")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _get_llm_response(query, db_input, dbsummary):
|
def _get_llm_response(query, db_input, dbsummary):
|
||||||
chat_param = {
|
chat_param = {
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
|
@ -43,15 +43,14 @@ CFG = Config()
|
|||||||
# "tps": 50
|
# "tps": 50
|
||||||
# }
|
# }
|
||||||
|
|
||||||
|
|
||||||
class MysqlSummary(DBSummary):
|
class MysqlSummary(DBSummary):
|
||||||
"""Get mysql summary template."""
|
"""Get mysql summary template."""
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.type = "MYSQL"
|
self.type = "MYSQL"
|
||||||
self.summery = (
|
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||||
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
|
||||||
)
|
|
||||||
self.tables = {}
|
self.tables = {}
|
||||||
self.tables_info = []
|
self.tables_info = []
|
||||||
self.vector_tables_info = []
|
self.vector_tables_info = []
|
||||||
@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
|
|||||||
self.tables[table_name] = table_summary.get_columns()
|
self.tables[table_name] = table_summary.get_columns()
|
||||||
self.table_columns_info.append(table_summary.get_columns())
|
self.table_columns_info.append(table_summary.get_columns())
|
||||||
# self.table_columns_json.append(table_summary.get_summary_json())
|
# self.table_columns_json.append(table_summary.get_summary_json())
|
||||||
table_profile = "table name:{table_name},table description:{table_comment}".format(
|
table_profile = (
|
||||||
table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
|
"table name:{table_name},table description:{table_comment}".format(
|
||||||
|
table_name=table_name,
|
||||||
|
table_comment=self.db.get_show_create_table(table_name),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self.table_columns_json.append(table_profile)
|
self.table_columns_json.append(table_profile)
|
||||||
# self.tables_info.append(table_summary.get_summery())
|
# self.tables_info.append(table_summary.get_summery())
|
||||||
|
|
||||||
@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
|
|||||||
|
|
||||||
def get_db_summery(self):
|
def get_db_summery(self):
|
||||||
return self.summery.format(
|
return self.summery.format(
|
||||||
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
|
name=self.name,
|
||||||
|
type=self.type,
|
||||||
|
tables=";".join(self.vector_tables_info),
|
||||||
|
qps=1000,
|
||||||
|
tps=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_table_summary(self):
|
def get_table_summary(self):
|
||||||
@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
|
|||||||
self.indexes_info.append(index_summary.get_summery())
|
self.indexes_info.append(index_summary.get_summery())
|
||||||
|
|
||||||
self.json_summery = self.json_summery_template.format(
|
self.json_summery = self.json_summery_template.format(
|
||||||
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
|
name=name,
|
||||||
|
comment=comment_map[name],
|
||||||
|
fields=self.fields_info,
|
||||||
|
indexes=self.indexes_info,
|
||||||
|
size_in_bytes=1000,
|
||||||
|
rows=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_summery(self):
|
def get_summery(self):
|
||||||
@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
|
|||||||
self.bind_fields = index[1]
|
self.bind_fields = index[1]
|
||||||
|
|
||||||
def get_summery(self):
|
def get_summery(self):
|
||||||
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
|
return self.summery_template.format(
|
||||||
|
name=self.name, bind_fields=self.bind_fields
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user