mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-18 00:07:45 +00:00
default scene change
This commit is contained in:
parent
ab1e3f51eb
commit
8851ab9d45
2
.gitignore
vendored
2
.gitignore
vendored
@ -46,7 +46,7 @@ MANIFEST
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
# Unit test_py / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
1145
logsDbChatOutputParser.log.2023-06-06
Normal file
1145
logsDbChatOutputParser.log.2023-06-06
Normal file
File diff suppressed because it is too large
Load Diff
@ -64,59 +64,8 @@ class Database:
|
||||
self._usable_tables = set()
|
||||
self._usable_tables = set()
|
||||
self._sample_rows_in_table_info = set()
|
||||
# including view support by adding the views as well as tables to the all
|
||||
# tables list if view_support is True
|
||||
# self._all_tables = set(
|
||||
# self._inspector.get_table_names(schema=schema)
|
||||
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
|
||||
# )
|
||||
|
||||
# self._include_tables = set(include_tables) if include_tables else set()
|
||||
# if self._include_tables:
|
||||
# missing_tables = self._include_tables - self._all_tables
|
||||
# if missing_tables:
|
||||
# raise ValueError(
|
||||
# f"include_tables {missing_tables} not found in database"
|
||||
# )
|
||||
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
|
||||
# if self._ignore_tables:
|
||||
# missing_tables = self._ignore_tables - self._all_tables
|
||||
# if missing_tables:
|
||||
# raise ValueError(
|
||||
# f"ignore_tables {missing_tables} not found in database"
|
||||
# )
|
||||
# usable_tables = self.get_usable_table_names()
|
||||
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
|
||||
|
||||
# if not isinstance(sample_rows_in_table_info, int):
|
||||
# raise TypeError("sample_rows_in_table_info must be an integer")
|
||||
#
|
||||
# self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
self._indexes_in_table_info = indexes_in_table_info
|
||||
#
|
||||
# self._custom_table_info = custom_table_info
|
||||
# if self._custom_table_info:
|
||||
# if not isinstance(self._custom_table_info, dict):
|
||||
# raise TypeError(
|
||||
# "table_info must be a dictionary with table names as keys and the "
|
||||
# "desired table info as values"
|
||||
# )
|
||||
# # only keep the tables that are also present in the database
|
||||
# intersection = set(self._custom_table_info).intersection(self._all_tables)
|
||||
# self._custom_table_info = dict(
|
||||
# (table, self._custom_table_info[table])
|
||||
# for table in self._custom_table_info
|
||||
# if table in intersection
|
||||
# )
|
||||
|
||||
# self._metadata = metadata or MetaData()
|
||||
# # # including view support if view_support = true
|
||||
# self._metadata.reflect(
|
||||
# views=view_support,
|
||||
# bind=self._engine,
|
||||
# only=list(self._usable_tables),
|
||||
# schema=self._schema,
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
|
@ -36,7 +36,7 @@ class Config(metaclass=Singleton):
|
||||
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
|
||||
)
|
||||
|
||||
# This is a proxy server, just for test. we will remove this later.
|
||||
# This is a proxy server, just for test_py. we will remove this later.
|
||||
self.proxy_api_key = os.getenv("PROXY_API_KEY")
|
||||
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
|
||||
|
||||
@ -112,6 +112,7 @@ class Config(metaclass=Singleton):
|
||||
|
||||
### Local database connection configuration
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
|
||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||
|
@ -1,8 +1,35 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
class ClickHouseConnector:
|
||||
CFG = Config()
|
||||
|
||||
class ClickHouseConnector(RDBMSDatabase):
|
||||
"""ClickHouseConnector"""
|
||||
|
||||
pass
|
||||
type: str = "DUCKDB"
|
||||
|
||||
driver: str = "duckdb"
|
||||
|
||||
file_path: str
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls) -> RDBMSDatabase:
|
||||
"""
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(cls, db_path: str,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + db_path
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
38
pilot/connections/rdbms/duckdb.py
Normal file
38
pilot/connections/rdbms/duckdb.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
type: str = "DUCKDB"
|
||||
|
||||
driver: str = "duckdb"
|
||||
|
||||
file_path: str
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls) -> RDBMSDatabase:
|
||||
"""
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_PATH,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(cls, db_path: str,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + db_path
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
class ElasticSearchConnector:
|
||||
"""ElasticSearchConnector"""
|
||||
|
||||
pass
|
@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
class MongoConnector:
|
||||
"""MongoConnector is a class which connect to mongo and chat with LLM"""
|
||||
|
||||
pass
|
23
pilot/connections/rdbms/mssql.py
Normal file
23
pilot/connections/rdbms/mssql.py
Normal file
@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
type: str = "MSSQL"
|
||||
dialect: str = "mssql"
|
||||
driver: str = "pyodbc"
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
|
||||
|
||||
|
@ -1,17 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any
|
||||
|
||||
import pymysql
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MySQLConnect(RDBMSDatabase):
|
||||
"""Connect MySQL Database fetch MetaData For LLM Prompt
|
||||
"""Connect MySQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
type: str = "MySQL"
|
||||
connect_url = "mysql+pymysql://"
|
||||
dialect: str = "mysql"
|
||||
driver: str = "pymysql"
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
|
||||
|
@ -1,8 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class OracleConnector:
|
||||
class OracleConnector(RDBMSDatabase):
|
||||
"""OracleConnector"""
|
||||
type: str = "ORACLE"
|
||||
|
||||
pass
|
||||
driver: str = "oracle"
|
||||
|
||||
default_db = ["SYS", "SYSTEM", "OUTLN", "ORDDATA", "XDB"]
|
||||
|
@ -1,8 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
class PostgresConnector(RDBMSDatabase):
|
||||
"""PostgresConnector is a class which Connector"""
|
||||
|
||||
class PostgresConnector:
|
||||
"""PostgresConnector is a class which Connector to chat with LLM"""
|
||||
type: str = "POSTGRESQL"
|
||||
driver: str = "postgresql"
|
||||
|
||||
pass
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
0
pilot/connections/rdbms/py_study/__init__.py
Normal file
0
pilot/connections/rdbms/py_study/__init__.py
Normal file
54
pilot/connections/rdbms/py_study/pd_study.py
Normal file
54
pilot/connections/rdbms/py_study/pd_study.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pilot.configs.config import Config
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, pool
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.font_manager import FontProperties
|
||||
from pyecharts.charts import Bar
|
||||
from pyecharts import options as opts
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建连接池
|
||||
engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
|
||||
|
||||
# 从连接池中获取连接
|
||||
|
||||
|
||||
# 归还连接到连接池中
|
||||
|
||||
# 执行SQL语句并将结果转化为DataFrame
|
||||
query = "SELECT * FROM users"
|
||||
df = pd.read_sql(query, engine.connect())
|
||||
df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
|
||||
# 导出为HTML文件
|
||||
with open('report.html', 'w') as f:
|
||||
f.write(df.style.render())
|
||||
|
||||
# # 设置中文字体
|
||||
# font = FontProperties(fname='SimHei.ttf', size=14)
|
||||
#
|
||||
# colors = np.random.rand(df.shape[0])
|
||||
# df.plot.scatter(x='city', y='user_name', c=colors)
|
||||
# plt.show()
|
||||
|
||||
# 查看DataFrame
|
||||
print(df.head())
|
||||
|
||||
|
||||
# 创建数据
|
||||
x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||
y_data = [820, 932, 901, 934, 1290, 1330, 1320]
|
||||
|
||||
# 生成图表
|
||||
bar = (
|
||||
Bar()
|
||||
.add_xaxis(x_data)
|
||||
.add_yaxis("销售额", y_data)
|
||||
.set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
|
||||
)
|
||||
|
||||
# 生成HTML文件
|
||||
bar.render('report.html')
|
@ -19,6 +19,9 @@ from sqlalchemy.schema import CreateTable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
||||
from pilot.connections.base import BaseConnect
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
@ -32,16 +35,13 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
@ -55,73 +55,33 @@ class RDBMSDatabase(BaseConnect):
|
||||
|
||||
self._db_sessions = Session
|
||||
|
||||
self._all_tables = set()
|
||||
self.view_support = False
|
||||
self._usable_tables = set()
|
||||
self._include_tables = set()
|
||||
self._ignore_tables = set()
|
||||
self._custom_table_info = set()
|
||||
self._indexes_in_table_info = set()
|
||||
self._usable_tables = set()
|
||||
self._usable_tables = set()
|
||||
self._sample_rows_in_table_info = set()
|
||||
# including view support by adding the views as well as tables to the all
|
||||
# tables list if view_support is True
|
||||
# self._all_tables = set(
|
||||
# self._inspector.get_table_names(schema=schema)
|
||||
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
|
||||
# )
|
||||
@classmethod
|
||||
def from_config(cls) -> RDBMSDatabase:
|
||||
"""
|
||||
Todo password encryption
|
||||
Returns:
|
||||
"""
|
||||
return cls.from_uri_db(cls,
|
||||
CFG.LOCAL_DB_HOST,
|
||||
CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER,
|
||||
CFG.LOCAL_DB_PASSWORD,
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
|
||||
# self._include_tables = set(include_tables) if include_tables else set()
|
||||
# if self._include_tables:
|
||||
# missing_tables = self._include_tables - self._all_tables
|
||||
# if missing_tables:
|
||||
# raise ValueError(
|
||||
# f"include_tables {missing_tables} not found in database"
|
||||
# )
|
||||
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
|
||||
# if self._ignore_tables:
|
||||
# missing_tables = self._ignore_tables - self._all_tables
|
||||
# if missing_tables:
|
||||
# raise ValueError(
|
||||
# f"ignore_tables {missing_tables} not found in database"
|
||||
# )
|
||||
# usable_tables = self.get_usable_table_names()
|
||||
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
|
||||
|
||||
# if not isinstance(sample_rows_in_table_info, int):
|
||||
# raise TypeError("sample_rows_in_table_info must be an integer")
|
||||
#
|
||||
# self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
# self._indexes_in_table_info = indexes_in_table_info
|
||||
#
|
||||
# self._custom_table_info = custom_table_info
|
||||
# if self._custom_table_info:
|
||||
# if not isinstance(self._custom_table_info, dict):
|
||||
# raise TypeError(
|
||||
# "table_info must be a dictionary with table names as keys and the "
|
||||
# "desired table info as values"
|
||||
# )
|
||||
# # only keep the tables that are also present in the database
|
||||
# intersection = set(self._custom_table_info).intersection(self._all_tables)
|
||||
# self._custom_table_info = dict(
|
||||
# (table, self._custom_table_info[table])
|
||||
# for table in self._custom_table_info
|
||||
# if table in intersection
|
||||
# )
|
||||
|
||||
# self._metadata = metadata or MetaData()
|
||||
# # # including view support if view_support = true
|
||||
# self._metadata.reflect(
|
||||
# views=view_support,
|
||||
# bind=self._engine,
|
||||
# only=list(self._usable_tables),
|
||||
# schema=self._schema,
|
||||
# )
|
||||
@classmethod
|
||||
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
|
||||
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
|
||||
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
|
||||
CFG.LOCAL_DB_PORT)
|
||||
if cls.dialect:
|
||||
db_url = cls.dialect + "+" + db_url
|
||||
if db_name:
|
||||
db_url = db_url + "/" + db_name
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
@ -207,7 +167,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
]
|
||||
|
||||
tables = []
|
||||
@ -220,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
if has_extra_info:
|
||||
table_info += "\n\n/*"
|
||||
|
@ -168,6 +168,6 @@ register_llm_model_adapters(ChatGLMAdapater)
|
||||
register_llm_model_adapters(GuanacoAdapter)
|
||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||
|
||||
# just for test, remove this later
|
||||
# just for test_py, remove this later
|
||||
register_llm_model_adapters(ProxyllmAdapter)
|
||||
register_llm_model_adapters(BaseLLMAdaper)
|
||||
|
@ -62,7 +62,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
history.append(last_user_input)
|
||||
|
||||
payloads = {
|
||||
"model": "gpt-3.5-turbo", # just for test, remove this later
|
||||
"model": "gpt-3.5-turbo", # just for test_py, remove this later
|
||||
"messages": history,
|
||||
"temperature": params.get("temperature"),
|
||||
"max_tokens": params.get("max_new_tokens"),
|
||||
|
@ -110,7 +110,7 @@ train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)
|
||||
|
||||
train_data = train_val["train"].map(generate_and_tokenize_prompt)
|
||||
|
||||
val_data = train_val["test"].map(generate_and_tokenize_prompt)
|
||||
val_data = train_val["test_py"].map(generate_and_tokenize_prompt)
|
||||
|
||||
# Training
|
||||
LORA_R = 8
|
||||
|
@ -70,10 +70,8 @@ class BaseChat(ABC):
|
||||
self.current_user_input: str = current_user_input
|
||||
self.llm_model = CFG.LLM_MODEL
|
||||
### can configurable storage methods
|
||||
# self.memory = MemHistoryMemory(chat_session_id)
|
||||
self.memory = MemHistoryMemory(chat_session_id)
|
||||
|
||||
## TEST
|
||||
self.memory = FileHistoryMemory(chat_session_id)
|
||||
### load prompt template
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||
self.chat_mode.value
|
||||
|
@ -47,7 +47,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf"
|
||||
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
|
||||
)
|
||||
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
|
Loading…
Reference in New Issue
Block a user