pylint: multi model for gp4all (#138)

This commit is contained in:
csunny 2023-06-14 10:17:53 +08:00
parent 3927c26dea
commit ff6d3a7035
25 changed files with 154 additions and 111 deletions

View File

@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
print("load_native_plugins")
### TODO 默认拉主分支,后续拉发布版本
branch_name = cfg.plugins_git_branch
native_plugin_repo ="DB-GPT-Plugins"
native_plugin_repo = "DB-GPT-Plugins"
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
headers={'Authorization': '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'})
response = requests.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S')
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = (
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
)
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)

View File

@ -66,7 +66,6 @@ class Database:
self._sample_rows_in_table_info = set()
self._indexes_in_table_info = indexes_in_table_info
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
@ -399,7 +398,6 @@ class Database:
ans = cursor.fetchall()
return ans[0][1]
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()

View File

@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
current_directory = os.getcwd()

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config()
class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector"""
@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod
def from_uri_db(cls, db_path: str,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
def from_uri_db(
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config()
class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData
Args:
@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod
def from_uri_db(cls, db_path: str,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
def from_uri_db(
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase):
"""Connect MSSQL Database fetch MetaData
Args:
@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase):
dialect: str = "mssql"
driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData
Args:
@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -2,8 +2,10 @@
# -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase):
"""OracleConnector"""
type: str = "ORACLE"
driver: str = "oracle"

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector"""

View File

@ -57,18 +57,19 @@ CFG = Config()
if __name__ == "__main__":
def __extract_json(s):
i = s.index('{')
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == '}':
count -= 1
elif c == '{':
count += 1
if count == 0:
break
assert (count == 0) # 检查是否找到最后一个'}'
return s[i:j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
print(__extract_json(ss))
def __extract_json(s):
i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
print(__extract_json(ss))

View File

@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
):
"""Create engine from database URI."""
self._engine = engine
@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
CFG.LOCAL_DB_PORT)
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str = None,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.connect_driver
+ "://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
if cls.dialect:
db_url = cls.dialect + "+" + db_url
if db_name:
@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"

View File

@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
}
)
# Move the last user's information to the end
# Move the last user's information to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
text = ""
for line in res.iter_lines():
if line:
json_data = line.split(b': ', 1)[1]
json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data)

View File

@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(slef, s):
i = s.index('{')
i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == '}':
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == '{':
elif c == "{":
count += 1
if count == 0:
break
assert (count == 0) # 检查是否找到最后一个'}'
return s[i:j + 1]
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""

View File

@ -134,7 +134,6 @@ class BaseChat(ABC):
return payload
def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base()
@ -189,19 +188,19 @@ class BaseChat(ABC):
)
)
# ### MOCK
# ai_response_text = """{
# "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
# "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
# "command": {
# "name": "histogram-executor",
# "args": {
# "title": "订单城市分布柱状图",
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
# }
# }
# }"""
# ### MOCK
# ai_response_text = """{
# "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
# "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
# "command": {
# "name": "histogram-executor",
# "args": {
# "title": "订单城市分布柱状图",
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
# }
# }
# }"""
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (

View File

@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
def generate(self, p) -> str:
return super().generate(p)

View File

@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
command, thoughts, speak = (
response["command"],
response["thoughts"],
response["speak"]
response["speak"],
)
return PluginAction(command, speak, thoughts)

View File

@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
context = context[:2000]
input_values = {"context": context, "question": self.current_user_input}
except NoIndexException:
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
raise ValueError(
"you have no default knowledge store, please execute python knowledge_init.py"
)
return input_values
def do_with_prompt_response(self, prompt_response):

View File

@ -5,7 +5,6 @@ import sys
from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42")
random.seed(42)

View File

@ -87,7 +87,10 @@ class ModelWorker:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
ret = {"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0}
ret = {
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt):

View File

@ -667,8 +667,8 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info(f"args: {args}")
# init config
# init config
cfg = Config()
load_native_plugins(cfg)
@ -682,7 +682,7 @@ if __name__ == "__main__":
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# exclude commands
# exclude commands
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]

View File

@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self):
"""Load from markdown path."""
loader = EncodeTextLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter)
@register

View File

@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter)

View File

@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
def read(self):
"""Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter)
@register

View File

@ -78,7 +78,7 @@ class DBSummaryClient:
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
table_docs =knowledge_embedding_client.similar_search(query, topk)
table_docs = knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs]
return ans
@ -147,8 +147,6 @@ class DBSummaryClient:
logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
"temperature": 0.7,

View File

@ -43,15 +43,14 @@ CFG = Config()
# "tps": 50
# }
class MysqlSummary(DBSummary):
"""Get mysql summary template."""
def __init__(self, name):
self.name = name
self.type = "MYSQL"
self.summery = (
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
)
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.tables = {}
self.tables_info = []
self.vector_tables_info = []
@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
table_profile = "table name:{table_name},table description:{table_comment}".format(
table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
table_profile = (
"table name:{table_name},table description:{table_comment}".format(
table_name=table_name,
table_comment=self.db.get_show_create_table(table_name),
)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery())
@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
def get_db_summery(self):
return self.summery.format(
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
qps=1000,
tps=1000,
)
def get_table_summary(self):
@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
self.indexes_info.append(index_summary.get_summery())
self.json_summery = self.json_summery_template.format(
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
name=name,
comment=comment_map[name],
fields=self.fields_info,
indexes=self.indexes_info,
size_in_bytes=1000,
rows=1000,
)
def get_summery(self):
@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
self.bind_fields = index[1]
def get_summery(self):
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
return self.summery_template.format(
name=self.name, bind_fields=self.bind_fields
)
if __name__ == "__main__":