default scene change

This commit is contained in:
yhjun1026 2023-06-08 20:08:51 +08:00
parent ab1e3f51eb
commit 8851ab9d45
20 changed files with 1353 additions and 162 deletions

2
.gitignore vendored
View File

@ -46,7 +46,7 @@ MANIFEST
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
# Unit test_py / coverage reports
htmlcov/
.tox/
.nox/

File diff suppressed because it is too large Load Diff

View File

@ -64,59 +64,8 @@ class Database:
self._usable_tables = set()
self._usable_tables = set()
self._sample_rows_in_table_info = set()
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
# self._all_tables = set(
# self._inspector.get_table_names(schema=schema)
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
# )
# self._include_tables = set(include_tables) if include_tables else set()
# if self._include_tables:
# missing_tables = self._include_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"include_tables {missing_tables} not found in database"
# )
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
# if self._ignore_tables:
# missing_tables = self._ignore_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"ignore_tables {missing_tables} not found in database"
# )
# usable_tables = self.get_usable_table_names()
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
# if not isinstance(sample_rows_in_table_info, int):
# raise TypeError("sample_rows_in_table_info must be an integer")
#
# self._sample_rows_in_table_info = sample_rows_in_table_info
self._indexes_in_table_info = indexes_in_table_info
#
# self._custom_table_info = custom_table_info
# if self._custom_table_info:
# if not isinstance(self._custom_table_info, dict):
# raise TypeError(
# "table_info must be a dictionary with table names as keys and the "
# "desired table info as values"
# )
# # only keep the tables that are also present in the database
# intersection = set(self._custom_table_info).intersection(self._all_tables)
# self._custom_table_info = dict(
# (table, self._custom_table_info[table])
# for table in self._custom_table_info
# if table in intersection
# )
# self._metadata = metadata or MetaData()
# # # including view support if view_support = true
# self._metadata.reflect(
# views=view_support,
# bind=self._engine,
# only=list(self._usable_tables),
# schema=self._schema,
# )
@classmethod
def from_uri(

View File

@ -36,7 +36,7 @@ class Config(metaclass=Singleton):
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
)
# This is a proxy server, just for test. we will remove this later.
# This is a proxy server, just for test_py. we will remove this later.
self.proxy_api_key = os.getenv("PROXY_API_KEY")
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
@ -112,6 +112,7 @@ class Config(metaclass=Singleton):
### Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")

View File

@ -1,8 +1,35 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.configs.config import Config
class ClickHouseConnector:
CFG = Config()
class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector"""
pass
type: str = "DUCKDB"
driver: str = "duckdb"
file_path: str
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
@classmethod
def from_uri_db(cls, db_path: str,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -0,0 +1,38 @@
from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.configs.config import Config
CFG = Config()
class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData
Args:
Usage:
"""
type: str = "DUCKDB"
driver: str = "duckdb"
file_path: str
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
@classmethod
def from_uri_db(cls, db_path: str,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -1,8 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
class ElasticSearchConnector:
"""ElasticSearchConnector"""
pass

View File

@ -1,8 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
class MongoConnector:
"""MongoConnector is a class which connect to mongo and chat with LLM"""
pass

View File

@ -0,0 +1,23 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase):
"""Connect MSSQL Database fetch MetaData
Args:
Usage:
"""
type: str = "MSSQL"
dialect: str = "mssql"
driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]

View File

@ -1,17 +1,23 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any
import pymysql
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData For LLM Prompt
"""Connect MySQL Database fetch MetaData
Args:
Usage:
"""
type: str = "MySQL"
connect_url = "mysql+pymysql://"
dialect: str = "mysql"
driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector:
class OracleConnector(RDBMSDatabase):
"""OracleConnector"""
type: str = "ORACLE"
pass
driver: str = "oracle"
default_db = ["SYS", "SYSTEM", "OUTLN", "ORDDATA", "XDB"]

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector"""
class PostgresConnector:
"""PostgresConnector is a class which Connector to chat with LLM"""
type: str = "POSTGRESQL"
driver: str = "postgresql"
pass
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -0,0 +1,54 @@
from pilot.configs.config import Config
import pandas as pd
from sqlalchemy import create_engine, pool
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
CFG = Config()
if __name__ == "__main__":
# 创建连接池
engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
# 从连接池中获取连接
# 归还连接到连接池中
# 执行SQL语句并将结果转化为DataFrame
query = "SELECT * FROM users"
df = pd.read_sql(query, engine.connect())
df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
# 导出为HTML文件
with open('report.html', 'w') as f:
f.write(df.style.render())
# # 设置中文字体
# font = FontProperties(fname='SimHei.ttf', size=14)
#
# colors = np.random.rand(df.shape[0])
# df.plot.scatter(x='city', y='user_name', c=colors)
# plt.show()
# 查看DataFrame
print(df.head())
# 创建数据
x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
y_data = [820, 932, 901, 934, 1290, 1330, 1320]
# 生成图表
bar = (
Bar()
.add_xaxis(x_data)
.add_yaxis("销售额", y_data)
.set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
)
# 生成HTML文件
bar.render('report.html')

View File

@ -19,6 +19,9 @@ from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
from pilot.connections.base import BaseConnect
from pilot.configs.config import Config
CFG = Config()
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
@ -32,16 +35,13 @@ class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
):
"""Create engine from database URI."""
self._engine = engine
@ -55,73 +55,33 @@ class RDBMSDatabase(BaseConnect):
self._db_sessions = Session
self._all_tables = set()
self.view_support = False
self._usable_tables = set()
self._include_tables = set()
self._ignore_tables = set()
self._custom_table_info = set()
self._indexes_in_table_info = set()
self._usable_tables = set()
self._usable_tables = set()
self._sample_rows_in_table_info = set()
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
# self._all_tables = set(
# self._inspector.get_table_names(schema=schema)
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
# )
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(cls,
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
# self._include_tables = set(include_tables) if include_tables else set()
# if self._include_tables:
# missing_tables = self._include_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"include_tables {missing_tables} not found in database"
# )
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
# if self._ignore_tables:
# missing_tables = self._ignore_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"ignore_tables {missing_tables} not found in database"
# )
# usable_tables = self.get_usable_table_names()
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
# if not isinstance(sample_rows_in_table_info, int):
# raise TypeError("sample_rows_in_table_info must be an integer")
#
# self._sample_rows_in_table_info = sample_rows_in_table_info
# self._indexes_in_table_info = indexes_in_table_info
#
# self._custom_table_info = custom_table_info
# if self._custom_table_info:
# if not isinstance(self._custom_table_info, dict):
# raise TypeError(
# "table_info must be a dictionary with table names as keys and the "
# "desired table info as values"
# )
# # only keep the tables that are also present in the database
# intersection = set(self._custom_table_info).intersection(self._all_tables)
# self._custom_table_info = dict(
# (table, self._custom_table_info[table])
# for table in self._custom_table_info
# if table in intersection
# )
# self._metadata = metadata or MetaData()
# # # including view support if view_support = true
# self._metadata.reflect(
# views=view_support,
# bind=self._engine,
# only=list(self._usable_tables),
# schema=self._schema,
# )
@classmethod
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
CFG.LOCAL_DB_PORT)
if cls.dialect:
db_url = cls.dialect + "+" + db_url
if db_name:
db_url = db_url + "/" + db_name
return cls.from_uri(db_url, engine_args, **kwargs)
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@ -207,7 +167,7 @@ class RDBMSDatabase(BaseConnect):
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@ -220,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"

View File

@ -168,6 +168,6 @@ register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate
# just for test, remove this later
# just for test_py, remove this later
register_llm_model_adapters(ProxyllmAdapter)
register_llm_model_adapters(BaseLLMAdaper)

View File

@ -62,7 +62,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
history.append(last_user_input)
payloads = {
"model": "gpt-3.5-turbo", # just for test, remove this later
"model": "gpt-3.5-turbo", # just for test_py, remove this later
"messages": history,
"temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"),

View File

@ -110,7 +110,7 @@ train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)
train_data = train_val["train"].map(generate_and_tokenize_prompt)
val_data = train_val["test"].map(generate_and_tokenize_prompt)
val_data = train_val["test_py"].map(generate_and_tokenize_prompt)
# Training
LORA_R = 8

View File

@ -70,10 +70,8 @@ class BaseChat(ABC):
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
### can configurable storage methods
# self.memory = MemHistoryMemory(chat_session_id)
self.memory = MemHistoryMemory(chat_session_id)
## TEST
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value

View File

@ -47,7 +47,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
if __name__ == "__main__":
filepath = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf"
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
)
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
docs = loader.load()