Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-06 10:54:29 +00:00)

Merge branch 'ty_test' of https://github.com/csunny/DB-GPT into dbgpt_doc

Commit 05acabddf0
@@ -7,6 +7,10 @@
## For example, to disable coding related features, uncomment the next line
# DISABLED_COMMAND_CATEGORIES=

#*******************************************************************#
#** Webserver Port **#
#*******************************************************************#
WEB_SERVER_PORT=7860

#*******************************************************************#
#*** LLM PROVIDER ***#
@@ -17,6 +21,7 @@
#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
LIMIT_MODEL_CONCURRENCY=5
@@ -98,15 +103,20 @@ VECTOR_STORE_TYPE=Chroma
#MILVUS_SECURE=


#*******************************************************************#
#** WebServer Language Support **#
#*******************************************************************#
LANGUAGE=en
#LANGUAGE=zh


#*******************************************************************#
# ** PROXY_SERVER
# ** PROXY_SERVER (openai interface | chatGPT proxy service), use chatGPT as your LLM.
# ** If your server can reach openai, set PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
# ** Otherwise, if you have a chatGPT proxy server, set PROXY_SERVER_URL={your-proxy-server-ip:port/xxx}
#*******************************************************************#
PROXY_API_KEY=
PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
PROXY_API_KEY={your-openai-sk}
PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions


#*******************************************************************#
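A rough sketch of how a client consumes these two settings (hedged: it assumes the endpoint speaks the OpenAI-compatible chat completions protocol with a Bearer token header; the project's real client is in pilot/model/proxy_llm.py further down this diff):

import os
import requests

# Assumption: PROXY_SERVER_URL points at an OpenAI-compatible endpoint.
headers = {"Authorization": "Bearer " + os.getenv("PROXY_API_KEY", "")}
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
}
res = requests.post(os.getenv("PROXY_SERVER_URL"), headers=headers, json=payload)
print(res.json()["choices"][0]["message"]["content"])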
.gitignore (vendored): 2 changes
@@ -46,7 +46,7 @@ MANIFEST
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
# Unit test_py / coverage reports
htmlcov/
.tox/
.nox/
echarts.min.js (vendored, new file): 45 changes
File diff suppressed because one or more lines are too long
@@ -2,7 +2,10 @@

import json
import os
import glob
import zipfile
import requests
import datetime
from pathlib import Path
from typing import List
from urllib.parse import urlparse
@@ -12,6 +15,7 @@ import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate

from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger


@@ -69,6 +73,31 @@ def create_directory_if_not_exists(directory_path: str) -> bool:
    return True


def load_native_plugins(cfg: Config):
    print("load_native_plugins")
    ### TODO: pull the main branch by default; switch to release versions later
    branch_name = cfg.plugins_git_branch
    native_plugin_repo = "DB-GPT-Plugins"
    url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
    response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
                            headers={'Authorization': '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'})

    if response.status_code == 200:
        plugins_path_path = Path(PLUGINS_DIR)
        files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
        for file in files:
            os.remove(file)
        now = datetime.datetime.now()
        time_str = now.strftime('%Y%m%d%H%M%S')
        file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
        print(file_name)
        with open(file_name, "wb") as f:
            f.write(response.content)
        print("File saved locally")
    else:
        print("Failed to fetch release info, status code:", response.status_code)
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

@@ -83,7 +112,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
    current_dir = os.getcwd()
    print(current_dir)
    # Generic plugins
    plugins_path_path = Path(cfg.plugins_dir)
    plugins_path_path = Path(PLUGINS_DIR)

    logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
    logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
@@ -104,7 +133,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
            if (
                "_abc_impl" in a_keys
                and a_module.__name__ != "AutoGPTPluginTemplate"
                and denylist_allowlist_check(a_module.__name__, cfg)
                # and denylist_allowlist_check(a_module.__name__, cfg)
            ):
                loaded_plugins.append(a_module())
@@ -64,59 +64,8 @@ class Database:
        self._usable_tables = set()
        self._usable_tables = set()
        self._sample_rows_in_table_info = set()
        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        # self._all_tables = set(
        #     self._inspector.get_table_names(schema=schema)
        #     + (self._inspector.get_view_names(schema=schema) if view_support else [])
        # )

        # self._include_tables = set(include_tables) if include_tables else set()
        # if self._include_tables:
        #     missing_tables = self._include_tables - self._all_tables
        #     if missing_tables:
        #         raise ValueError(
        #             f"include_tables {missing_tables} not found in database"
        #         )
        # self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        # if self._ignore_tables:
        #     missing_tables = self._ignore_tables - self._all_tables
        #     if missing_tables:
        #         raise ValueError(
        #             f"ignore_tables {missing_tables} not found in database"
        #         )
        # usable_tables = self.get_usable_table_names()
        # self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        # if not isinstance(sample_rows_in_table_info, int):
        #     raise TypeError("sample_rows_in_table_info must be an integer")
        #
        # self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info
        #
        # self._custom_table_info = custom_table_info
        # if self._custom_table_info:
        #     if not isinstance(self._custom_table_info, dict):
        #         raise TypeError(
        #             "table_info must be a dictionary with table names as keys and the "
        #             "desired table info as values"
        #         )
        #     # only keep the tables that are also present in the database
        #     intersection = set(self._custom_table_info).intersection(self._all_tables)
        #     self._custom_table_info = dict(
        #         (table, self._custom_table_info[table])
        #         for table in self._custom_table_info
        #         if table in intersection
        #     )

        # self._metadata = metadata or MetaData()
        # # # including view support if view_support = true
        # self._metadata.reflect(
        #     views=view_support,
        #     bind=self._engine,
        #     only=list(self._usable_tables),
        #     schema=self._schema,
        # )

    @classmethod
    def from_uri(
@@ -17,8 +17,9 @@ class Config(metaclass=Singleton):
    def __init__(self) -> None:
        """Initialize the Config class"""

        # Gradio language version: en, cn
        # Gradio language version: en, zh
        self.LANGUAGE = os.getenv("LANGUAGE", "en")
        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

        self.debug_mode = False
        self.skip_reprompt = False
@@ -36,7 +37,7 @@ class Config(metaclass=Singleton):
            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
        )

        # This is a proxy server, just for test. we will remove this later.
        # This is a proxy server, just for test_py. we will remove this later.
        self.proxy_api_key = os.getenv("PROXY_API_KEY")
        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

@@ -87,10 +88,11 @@ class Config(metaclass=Singleton):
        self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")

        ### Plugin-related settings that control how plugins are loaded and used
        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []

        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
        if plugins_allowlist:
            self.plugins_allowlist = plugins_allowlist.split(",")
@@ -112,6 +114,7 @@ class Config(metaclass=Singleton):

        ### Local database connection configuration
        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
@@ -13,8 +13,18 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")

nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")

# Get the current working directory
current_directory = os.getcwd()
print("Current working directory:", current_directory)

# Set the current working directory
new_directory = PILOT_PATH
os.chdir(new_directory)
print("New working directory:", os.getcwd())

DEVICE = (
    "cuda"
@@ -1,8 +1,35 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

from pilot.configs.config import Config

class ClickHouseConnector:
CFG = Config()

class ClickHouseConnector(RDBMSDatabase):
    """ClickHouseConnector"""

    pass
    type: str = "DUCKDB"

    driver: str = "duckdb"

    file_path: str

    default_db = ["information_schema", "performance_schema", "sys", "mysql"]


    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """
        Todo password encryption
        Returns:
        """
        return cls.from_uri_db(cls,
                               CFG.LOCAL_DB_PATH,
                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})

    @classmethod
    def from_uri_db(cls, db_path: str,
                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
        db_url: str = cls.connect_driver + "://" + db_path
        return cls.from_uri(db_url, engine_args, **kwargs)
pilot/connections/rdbms/duckdb.py (new file): 38 changes
@@ -0,0 +1,38 @@
from typing import Optional, Any

from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

from pilot.configs.config import Config

CFG = Config()

class DuckDbConnect(RDBMSDatabase):
    """Connect Duckdb Database fetch MetaData
    Args:
    Usage:
    """

    type: str = "DUCKDB"

    driver: str = "duckdb"

    file_path: str

    default_db = ["information_schema", "performance_schema", "sys", "mysql"]


    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """
        Todo password encryption
        Returns:
        """
        return cls.from_uri_db(cls,
                               CFG.LOCAL_DB_PATH,
                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})

    @classmethod
    def from_uri_db(cls, db_path: str,
                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
        db_url: str = cls.connect_driver + "://" + db_path
        return cls.from_uri(db_url, engine_args, **kwargs)
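Note that `connect_driver` is not defined on DuckDbConnect in this diff (only `driver` is), so the URL assembly above presumably intends the DuckDB SQLAlchemy dialect. A hedged standalone sketch of the equivalent connection, assuming the duckdb-engine package is installed:

from sqlalchemy import create_engine, text

# duckdb-engine registers the "duckdb" dialect with SQLAlchemy.
engine = create_engine("duckdb:///pilot/mock_datas/db-gpt-test.db")
with engine.connect() as conn:
    print(conn.execute(text("SHOW TABLES")).fetchall())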
@@ -1,8 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


class ElasticSearchConnector:
    """ElasticSearchConnector"""

    pass
@@ -1,8 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


class MongoConnector:
    """MongoConnector is a class which connect to mongo and chat with LLM"""

    pass
pilot/connections/rdbms/mssql.py (new file): 23 changes
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any

from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase




class MSSQLConnect(RDBMSDatabase):
    """Connect MSSQL Database fetch MetaData
    Args:
    Usage:
    """

    type: str = "MSSQL"
    dialect: str = "mssql"
    driver: str = "pyodbc"

    default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
@@ -1,17 +1,23 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any

import pymysql
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase




class MySQLConnect(RDBMSDatabase):
    """Connect MySQL Database fetch MetaData For LLM Prompt
    """Connect MySQL Database fetch MetaData
    Args:
    Usage:
    """

    type: str = "MySQL"
    connect_url = "mysql+pymysql://"
    dialect: str = "mysql"
    driver: str = "pymysql"

    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


class OracleConnector:
class OracleConnector(RDBMSDatabase):
    """OracleConnector"""
    type: str = "ORACLE"

    pass
    driver: str = "oracle"

    default_db = ["SYS", "SYSTEM", "OUTLN", "ORDDATA", "XDB"]
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

class PostgresConnector(RDBMSDatabase):
    """PostgresConnector is a Connector class"""

class PostgresConnector:
    """PostgresConnector is a Connector class to chat with LLM"""
    type: str = "POSTGRESQL"
    driver: str = "postgresql"

    pass
    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
pilot/connections/rdbms/py_study/__init__.py (new file): 0 changes
pilot/connections/rdbms/py_study/pd_study.py (new file): 74 changes
@@ -0,0 +1,74 @@
from pilot.configs.config import Config
import pandas as pd
from sqlalchemy import create_engine, pool
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts

CFG = Config()

#
# if __name__ == "__main__":
#     # Create a connection pool
#     engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
#
#     # Get a connection from the pool
#
#
#     # Return the connection to the pool
#
#     # Run the SQL statement and convert the result to a DataFrame
#     query = "SELECT * FROM users"
#     df = pd.read_sql(query, engine.connect())
#     df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
#     # Export as an HTML file
#     with open('report.html', 'w') as f:
#         f.write(df.style.render())
#
#     # # Set a Chinese font
#     # font = FontProperties(fname='SimHei.ttf', size=14)
#     #
#     # colors = np.random.rand(df.shape[0])
#     # df.plot.scatter(x='city', y='user_name', c=colors)
#     # plt.show()
#
#     # Inspect the DataFrame
#     print(df.head())
#
#
#     # Create the data
#     x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
#     y_data = [820, 932, 901, 934, 1290, 1330, 1320]
#
#     # Build the chart
#     bar = (
#         Bar()
#         .add_xaxis(x_data)
#         .add_yaxis("Sales", y_data)
#         .set_global_opts(title_opts=opts.TitleOpts(title="Sales statistics"))
#     )
#
#     # Render to an HTML file
#     bar.render('report.html')
#
#


if __name__ == "__main__":
    def __extract_json(s):
        i = s.index('{')
        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
        for j, c in enumerate(s[i + 1:], start=i + 1):
            if c == '}':
                count -= 1
            elif c == '{':
                count += 1
            if count == 0:
                break
        assert (count == 0)  # verify the matching '}' was found
        return s[i:j + 1]

    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
    print(__extract_json(ss))
@@ -19,6 +19,9 @@ from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session

from pilot.connections.base import BaseConnect
from pilot.configs.config import Config

CFG = Config()


def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
@@ -32,16 +35,13 @@ class RDBMSDatabase(BaseConnect):
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,

    ):
        """Create engine from database URI."""
        self._engine = engine
@@ -55,73 +55,33 @@ class RDBMSDatabase(BaseConnect):

        self._db_sessions = Session

        self._all_tables = set()
        self.view_support = False
        self._usable_tables = set()
        self._include_tables = set()
        self._ignore_tables = set()
        self._custom_table_info = set()
        self._indexes_in_table_info = set()
        self._usable_tables = set()
        self._usable_tables = set()
        self._sample_rows_in_table_info = set()
        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        # self._all_tables = set(
        #     self._inspector.get_table_names(schema=schema)
        #     + (self._inspector.get_view_names(schema=schema) if view_support else [])
        # )
    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """
        Todo password encryption
        Returns:
        """
        return cls.from_uri_db(cls,
                               CFG.LOCAL_DB_HOST,
                               CFG.LOCAL_DB_PORT,
                               CFG.LOCAL_DB_USER,
                               CFG.LOCAL_DB_PASSWORD,
                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})

        # self._include_tables = set(include_tables) if include_tables else set()
        # if self._include_tables:
        #     missing_tables = self._include_tables - self._all_tables
        #     if missing_tables:
        #         raise ValueError(
        #             f"include_tables {missing_tables} not found in database"
        #         )
        # self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        # if self._ignore_tables:
        #     missing_tables = self._ignore_tables - self._all_tables
        #     if missing_tables:
        #         raise ValueError(
        #             f"ignore_tables {missing_tables} not found in database"
        #         )
        # usable_tables = self.get_usable_table_names()
        # self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        # if not isinstance(sample_rows_in_table_info, int):
        #     raise TypeError("sample_rows_in_table_info must be an integer")
        #
        # self._sample_rows_in_table_info = sample_rows_in_table_info
        # self._indexes_in_table_info = indexes_in_table_info
        #
        # self._custom_table_info = custom_table_info
        # if self._custom_table_info:
        #     if not isinstance(self._custom_table_info, dict):
        #         raise TypeError(
        #             "table_info must be a dictionary with table names as keys and the "
        #             "desired table info as values"
        #         )
        #     # only keep the tables that are also present in the database
        #     intersection = set(self._custom_table_info).intersection(self._all_tables)
        #     self._custom_table_info = dict(
        #         (table, self._custom_table_info[table])
        #         for table in self._custom_table_info
        #         if table in intersection
        #     )

        # self._metadata = metadata or MetaData()
        # # # including view support if view_support = true
        # self._metadata.reflect(
        #     views=view_support,
        #     bind=self._engine,
        #     only=list(self._usable_tables),
        #     schema=self._schema,
        # )
    @classmethod
    def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
        db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
            CFG.LOCAL_DB_PORT)
        if cls.dialect:
            db_url = cls.dialect + "+" + db_url
        if db_name:
            db_url = db_url + "/" + db_name
        return cls.from_uri(db_url, engine_args, **kwargs)
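        # A hedged trace with the LOCAL_DB_* defaults from config.py above,
        # taking MySQLConnect (dialect "mysql") as the example and assuming
        # cls.connect_driver resolves to the driver name "pymysql":
        #   base url     -> "pymysql://root:aa123456@127.0.0.1:3306"
        #   with dialect -> "mysql+pymysql://root:aa123456@127.0.0.1:3306"
        #   with db_name -> "mysql+pymysql://root:aa123456@127.0.0.1:3306/my_db"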

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Construct a SQLAlchemy engine from URI."""
        _engine_args = engine_args or {}
@@ -207,7 +167,7 @@ class RDBMSDatabase(BaseConnect):
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in set(all_table_names)
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
@@ -220,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
pilot/mock_datas/db-gpt-test.db (new binary file)
Binary file not shown.
@@ -213,6 +213,6 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
# TODO: vicuna is supported by default; other models still need testing and evaluation

# just for test, remove this later
# just for test_py, remove this later
register_llm_model_adapters(ProxyllmAdapter)
register_llm_model_adapters(BaseLLMAdaper)
@@ -66,6 +66,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
        "messages": history,
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True
    }

    res = requests.post(
@@ -75,14 +76,32 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
    text = ""
    for line in res.iter_lines():
        if line:
            decoded_line = line.decode("utf-8")
            try:
                json_line = json.loads(decoded_line)
                print(json_line)
                text += json_line["choices"][0]["message"]["content"]
                yield text
            except Exception as e:
                text += decoded_line
                yield json.loads(text)["choices"][0]["message"]["content"]

            json_data = line.split(b': ', 1)[1]
            decoded_line = json_data.decode("utf-8")
            if decoded_line.lower() != '[DONE]'.lower():
                obj = json.loads(json_data)
                if obj['choices'][0]['delta'].get('content') is not None:
                    content = obj['choices'][0]['delta']['content']
                    text += content
                    yield text
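            # A hedged trace of the OpenAI-style SSE frames the new loop above
            # consumes (assumed shape; only the fields read here are shown):
            #   b'data: {"choices": [{"delta": {"content": "Hel"}}]}'
            #   b'data: {"choices": [{"delta": {"content": "lo"}}]}'
            #   b'data: [DONE]'
            # split(b': ', 1)[1] strips the "data: " prefix, "[DONE]" marks the
            # end of the stream, and each delta's "content" piece is appended
            # to `text` before yielding the accumulated string.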


# native result.
# payloads = {
#     "model": "gpt-3.5-turbo",  # just for test, remove this later
#     "messages": history,
#     "temperature": params.get("temperature"),
#     "max_tokens": params.get("max_new_tokens"),
# }
#
# res = requests.post(
#     CFG.proxy_server_url, headers=headers, json=payloads, stream=True
# )
#
# text = ""
# line = res.content
# if line:
#     decoded_line = line.decode("utf-8")
#     json_line = json.loads(decoded_line)
#     print(json_line)
#     text += json_line["choices"][0]["message"]["content"]
#     yield text
@@ -95,7 +95,6 @@ class BaseOutputParser(ABC):
    def parse_model_nostream_resp(self, response, sep: str):
        text = response.text.strip()
        text = text.rstrip()
        text = text.lower()
        respObj = json.loads(text)

        xx = respObj["response"]
@@ -111,7 +110,9 @@ class BaseOutputParser(ABC):
                last_index = i
            ai_response = tmpResp[last_index]
            ai_response = ai_response.replace("assistant:", "")
            ai_response = ai_response.replace("\n", "")
            ai_response = ai_response.replace("Assistant:", "")
            ai_response = ai_response.replace("ASSISTANT:", "")
            ai_response = ai_response.replace("\n", " ")
            ai_response = ai_response.replace("\_", "_")
            ai_response = ai_response.replace("\*", "*")
            print("un_stream ai response:", ai_response)
@@ -119,6 +120,19 @@ class BaseOutputParser(ABC):
        else:
            raise ValueError("Model server error!code=" + respObj_ex["error_code"])

    def __extract_json(self, s):
        i = s.index('{')
        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
        for j, c in enumerate(s[i + 1:], start=i + 1):
            if c == '}':
                count -= 1
            elif c == '{':
                count += 1
            if count == 0:
                break
        assert (count == 0)  # verify the matching '}' was found
        return s[i:j + 1]
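        # Sketch of the brace matching above: the count rises on '{' and falls
        # on '}', so s[i:j + 1] is the first balanced JSON object even when
        # braces nest, e.g.
        #   'model said: {"speak": "hi", "command": {"name": "x"}} extra'
        #   -> '{"speak": "hi", "command": {"name": "x"}}'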

    def parse_prompt_response(self, model_out_text) -> T:
        """
        parse model out text to prompt define response
@@ -129,8 +143,8 @@ class BaseOutputParser(ABC):

        """
        cleaned_output = model_out_text.rstrip()
        # if "```json" in cleaned_output:
        #     _, cleaned_output = cleaned_output.split("```json")
        if "```json" in cleaned_output:
            _, cleaned_output = cleaned_output.split("```json")
        # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
        if cleaned_output.startswith("```json"):
@@ -142,18 +156,12 @@ class BaseOutputParser(ABC):
        cleaned_output = cleaned_output.strip()
        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
            logger.info("illegal json processing")
            json_pattern = r"{(.+?)}"
            m = re.search(json_pattern, cleaned_output)
            if m:
                cleaned_output = m.group(0)
            else:
                raise ValueError("model server output does not follow the prompt!")
            cleaned_output = self.__extract_json(cleaned_output)
        cleaned_output = (
            cleaned_output.strip()
            .replace("\n", "")
            .replace("\\n", "")
            .replace("\\", "")
            .replace("\\", "")
            .replace("\n", " ")
            .replace("\\n", " ")
            .replace("\\", " ")
        )
        return cleaned_output
@@ -110,7 +110,7 @@ train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)

train_data = train_val["train"].map(generate_and_tokenize_prompt)

val_data = train_val["test"].map(generate_and_tokenize_prompt)
val_data = train_val["test_py"].map(generate_and_tokenize_prompt)

# Training
LORA_R = 8
@@ -70,10 +70,8 @@ class BaseChat(ABC):
        self.current_user_input: str = current_user_input
        self.llm_model = CFG.LLM_MODEL
        ### the storage backend is configurable
        # self.memory = MemHistoryMemory(chat_session_id)
        self.memory = MemHistoryMemory(chat_session_id)

        ## TEST
        self.memory = FileHistoryMemory(chat_session_id)
        ### load prompt template
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
            self.chat_mode.value
@@ -188,6 +186,21 @@ class BaseChat(ABC):
                        response, self.prompt_template.sep
                    )
                )

            # ### MOCK
            # ai_response_text = """{
            #     "thoughts": "We can join the users table with the tran_order table, group by city to count orders, and show the result as a bar chart.",
            #     "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables joined with a LEFT JOIN, group by city, and count the orders per city. A bar chart makes the per-city order counts easy to compare at a glance.",
            #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders per city, and generated a bar chart.",
            #     "command": {
            #         "name": "histogram-executor",
            #         "args": {
            #             "title": "Bar chart of order distribution by city",
            #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
            #         }
            #     }
            # }"""

            self.current_message.add_ai_message(ai_response_text)
            prompt_define_response = (
                self.prompt_template.output_parser.parse_prompt_response(
@@ -78,11 +78,8 @@ class ChatWithPlugin(BaseChat):
        super().chat_show()

    def __list_to_prompt_str(self, list: List) -> str:
        if list:
            separator = "\n"
            return separator.join(list)
        else:
            return ""
        return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))


    def generate(self, p) -> str:
        return super().generate(p)
@@ -14,20 +14,26 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
    command: Dict
    speak: str
    reasoning: str
    thoughts: str


class PluginChatOutputParser(BaseOutputParser):
    def parse_prompt_response(self, model_out_text) -> T:
        response = json.loads(super().parse_prompt_response(model_out_text))
        command, thoughts, speak, reasoning = (
        clean_json_str = super().parse_prompt_response(model_out_text)
        print(clean_json_str)
        if not clean_json_str:
            raise ValueError("model server response contains no json!")
        try:
            response = json.loads(clean_json_str)
        except Exception as e:
            raise ValueError("model server output does not follow the prompt!")

        command, thoughts, speak = (
            response["command"],
            response["thoughts"],
            response["speak"],
            response["reasoning"],
            response["speak"]
        )
        return PluginAction(command, speak, reasoning, thoughts)
        return PluginAction(command, speak, thoughts)

    def parse_view_response(self, speak, data) -> str:
        ### tool out data to table view
@@ -10,7 +10,7 @@ from pilot.scene.chat_execution.out_parser import PluginChatOutputParser

CFG = Config()

PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers. Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""

PROMPT_SUFFIX = """
Goals:
@@ -20,25 +20,22 @@ Goals:

_DEFAULT_TEMPLATE = """
Constraints:
Exclusively use the commands listed in double quotes e.g. "command name"
Reflect on past decisions and strategies to refine your approach.
Constructively self-criticize your big-picture behavior constantly.
{constraints}
0. Exclusively use the commands listed in double quotes e.g. "command name"
{constraints}

Commands:
{commands_infos}
{commands_infos}
"""


PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}

PROMPT_RESPONSE = """
Please respond strictly in the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""

RESPONSE_FORMAT = {
    "thoughts": "thought text",
    "reasoning": "reasoning",
    "speak": "thoughts summary to say to user",
    "command": {"name": "command name", "args": {"arg name": "value"}},
}
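Putting the pieces above together, the full prompt is plausibly assembled along these lines (a sketch; the actual wiring lives in this scene's PromptTemplate registration, outside this hunk):

import json

prompt = (
    PROMPT_SCENE_DEFINE
    + PROMPT_SUFFIX
    + _DEFAULT_TEMPLATE
    + PROMPT_RESPONSE.format(response=json.dumps(RESPONSE_FORMAT, indent=4))
)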
@@ -4,6 +4,8 @@ import sys

from dotenv import load_dotenv



if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
    print("Setting random seed to 42")
    random.seed(42)
@@ -84,6 +84,11 @@ class ModelWorker:
        return get_embeddings(self.model, self.tokenizer, prompt)


model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(
    model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
)

app = FastAPI()


@@ -157,11 +162,4 @@ def embeddings(prompt_request: EmbeddingRequest):


if __name__ == "__main__":
    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
    print(model_path, DEVICE)

    worker = ModelWorker(
        model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
    )

    uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
@@ -35,7 +35,7 @@ from pilot.conversation import (
    chat_mode_title,
    default_conversation,
)
from pilot.common.plugins import scan_plugins
from pilot.common.plugins import scan_plugins, load_native_plugins

from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
@@ -658,7 +658,7 @@ def signal_handler(sig, frame):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
@@ -670,6 +670,7 @@ if __name__ == "__main__":
    # Initialize configuration
    cfg = Config()

    load_native_plugins(cfg)
    dbs = cfg.local_db.get_database_list()
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
@@ -47,7 +47,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):

if __name__ == "__main__":
    filepath = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf"
        os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
    )
    loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
    docs = loader.load()
@@ -17,7 +17,7 @@ importlib-resources==5.12.0

sqlparse==0.4.4
kiwisolver==1.4.4
matplotlib==3.7.0
matplotlib==3.7.1
multidict==6.0.4
packaging==23.0
psutil==5.9.4
@@ -3,6 +3,7 @@ import os
import pytest

from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.plugins import (
    denylist_allowlist_check,
    inspect_zip_for_modules,