mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat: Support SQLite connection
This commit is contained in:
parent
22658e36cf
commit
0859f36a89
@ -54,11 +54,17 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
|
||||
#*******************************************************************#
|
||||
#** DATABASE SETTINGS **#
|
||||
#*******************************************************************#
|
||||
### MySQL database (current default database)
|
||||
LOCAL_DB_TYPE=mysql
|
||||
LOCAL_DB_USER=root
|
||||
LOCAL_DB_PASSWORD=aa12345678
|
||||
LOCAL_DB_HOST=127.0.0.1
|
||||
LOCAL_DB_PORT=3306
|
||||
|
||||
### SQLite database (TODO: SQLite database will become the default database configuration when it is stable.)
|
||||
# LOCAL_DB_PATH=data/default_sqlite.db
|
||||
# LOCAL_DB_TYPE=sqlite
|
||||
|
||||
|
||||
### MILVUS
|
||||
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -145,6 +145,7 @@ nltk_data
|
||||
.vectordb
|
||||
pilot/data/
|
||||
pilot/nltk_data
|
||||
pilot/mock_datas/db-gpt-test.db.wal
|
||||
|
||||
logswebserver.log.*
|
||||
.history/*
|
||||
|
@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
|
||||
FROM ${BASE_IMAGE}
|
||||
ARG BASE_IMAGE
|
||||
|
||||
RUN apt-get update && apt-get install -y git python3 pip wget \
|
||||
RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
|
||||
&& apt-get clean
|
||||
|
||||
ARG BUILD_LOCAL_CODE="false"
|
||||
|
@ -1,5 +1,6 @@
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
import os
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
@ -24,6 +25,7 @@ class DBType(Enum):
|
||||
Mysql = DbInfo("mysql")
|
||||
OCeanBase = DbInfo("oceanbase")
|
||||
DuckDb = DbInfo("duckdb", True)
|
||||
SQLite = DbInfo("sqlite", True)
|
||||
Oracle = DbInfo("oracle")
|
||||
MSSQL = DbInfo("mssql")
|
||||
Postgresql = DbInfo("postgresql")
|
||||
@ -40,3 +42,12 @@ class DBType(Enum):
|
||||
if item.value() == db_type:
|
||||
return item
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
|
||||
"""Parse out the database name of the embedded database from the file path"""
|
||||
base_name = os.path.basename(local_db_path)
|
||||
db_name = os.path.splitext(base_name)[0]
|
||||
if "." in db_name:
|
||||
db_name = os.path.splitext(db_name)[0]
|
||||
return db_type + "_" + db_name
|
||||
|
@ -121,8 +121,12 @@ class Config(metaclass=Singleton):
|
||||
)
|
||||
|
||||
### default Local database connection configuration
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
|
||||
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
|
||||
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE")
|
||||
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
|
||||
self.LOCAL_DB_HOST = "127.0.0.1"
|
||||
|
||||
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
|
||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||
|
@ -8,6 +8,7 @@ from pilot.connections.base import BaseConnect
|
||||
|
||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
|
||||
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
|
||||
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
from pilot.singleton import Singleton
|
||||
@ -89,12 +90,21 @@ class ConnectManager:
|
||||
"",
|
||||
)
|
||||
if CFG.LOCAL_DB_PATH:
|
||||
# default file db is duckdb
|
||||
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
|
||||
db_name = CFG.LOCAL_DB_NAME
|
||||
db_type = CFG.LOCAL_DB_TYPE
|
||||
db_path = CFG.LOCAL_DB_PATH
|
||||
if not db_name:
|
||||
if db_type is None or db_type == DBType.DuckDb.value():
|
||||
# file db is duckdb
|
||||
db_name = self.storage.get_file_db_name(db_path)
|
||||
db_type = DBType.DuckDb.value()
|
||||
else:
|
||||
db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
|
||||
if db_name:
|
||||
self.storage.add_file_db(
|
||||
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
|
||||
print(
|
||||
f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
|
||||
)
|
||||
self.storage.add_file_db(db_name, db_type, db_path)
|
||||
|
||||
def get_connect(self, db_name):
|
||||
db_config = self.storage.get_db_config(db_name)
|
||||
|
@ -21,6 +21,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.connections.base import BaseConnect
|
||||
from pilot.configs.config import Config
|
||||
|
||||
@ -78,16 +79,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
self._metadata = MetaData()
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
|
||||
# including view support by adding the views as well as tables to the all
|
||||
# tables list if view_support is True
|
||||
self._all_tables = set(
|
||||
self._inspector.get_table_names(schema=self._engine.url.database)
|
||||
+ (
|
||||
self._inspector.get_view_names(schema=self._engine.url.database)
|
||||
if self.view_support
|
||||
else []
|
||||
)
|
||||
)
|
||||
self._all_tables = self._sync_tables_from_db()
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
@ -128,6 +120,26 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""Return string representation of dialect to use."""
|
||||
return self._engine.dialect.name
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
    """Reload the cached set of table names from the database.

    The names reported by the inspector are stored in ``self._all_tables``;
    view names are added as well when ``self.view_support`` is truthy.

    Returns:
        The refreshed ``self._all_tables`` set.
    """
    # TODO Use a background thread to refresh periodically

    # Per the original note, inspecting SQLite with an explicit schema raises
    # an error, so no schema is passed for it; every other database type is
    # inspected using the database name from the engine URL.
    if self.db_type == DBType.SQLite.value():
        schema_name = None
    else:
        schema_name = self._engine.url.database

    names = self._inspector.get_table_names(schema=schema_name)
    if self.view_support:
        names = names + self._inspector.get_view_names(schema=schema_name)

    self._all_tables = set(names)
    return self._all_tables
|
||||
|
||||
def get_usable_table_names(self) -> Iterable[str]:
|
||||
"""Get names of tables available."""
|
||||
if self._include_tables:
|
||||
@ -250,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def __write(self, session, write_sql):
|
||||
def _write(self, session, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
db_cache = self._engine.url.database
|
||||
result = session.execute(text(write_sql))
|
||||
@ -279,7 +291,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0] # type: ignore
|
||||
result = result = [cursor.fetchone()] # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
@ -305,11 +317,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0] # type: ignore
|
||||
result = [cursor.fetchone()] # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = list(i[0:] for i in cursor.keys())
|
||||
|
||||
result = list(result)
|
||||
return field_names, result
|
||||
|
||||
@ -323,7 +334,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
if sql_type == "SELECT":
|
||||
return self.__query(session, command, fetch)
|
||||
else:
|
||||
self.__write(session, command)
|
||||
self._write(session, command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
return self.__query(session, select_sql)
|
||||
@ -332,6 +343,11 @@ class RDBMSDatabase(BaseConnect):
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
cursor = session.execute(text(command))
|
||||
session.commit()
|
||||
_show_columns_sql = (
|
||||
f"PRAGMA table_info({table_name})"
|
||||
if self.db_type == "sqlite"
|
||||
else f"SHOW COLUMNS FROM {table_name}"
|
||||
)
|
||||
if cursor.returns_rows:
|
||||
result = cursor.fetchall()
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
@ -339,10 +355,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
if not result:
|
||||
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.__query(session, _show_columns_sql)
|
||||
return result
|
||||
else:
|
||||
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.__query(session, _show_columns_sql)
|
||||
|
||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
123
pilot/connections/rdbms/conn_sqlite.py
Normal file
123
pilot/connections/rdbms/conn_sqlite.py
Normal file
@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Any, Iterable
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class SQLiteConnect(RDBMSDatabase):
    """SQLite flavour of :class:`RDBMSDatabase`.

    Schema information is read from ``sqlite_master`` and ``PRAGMA``
    statements executed on ``self.session``.

    Usage:
        conn = SQLiteConnect.from_file_path("data/default_sqlite.db")
    """

    db_type: str = "sqlite"
    db_dialect: str = "sqlite"

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQLite identifier.

        Embedded double quotes are doubled so that names containing spaces,
        quotes or other special characters can be used in PRAGMA statements,
        which do not accept bound parameters.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Create a connection for the SQLite database file at ``file_path``.

        Args:
            file_path: Path of the database file, appended to ``sqlite:///``.
            engine_args: Optional keyword arguments for ``create_engine``.
            **kwargs: Forwarded unchanged to the class constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)

    def get_indexes(self, table_name):
        """Return ``(index name, origin)`` pairs for ``table_name``.

        The values are columns 1 and 3 of each ``PRAGMA index_list`` row.
        """
        quoted = self._quote_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA index_list({quoted})"))
        return [(index[1], index[3]) for index in cursor.fetchall()]

    def get_show_create_table(self, table_name):
        """Return the ``CREATE TABLE`` statement stored for ``table_name``.

        Raises:
            IndexError: If ``sqlite_master`` has no table with that name.
        """
        # Bind the name instead of interpolating it into the SQL text.
        cursor = self.session.execute(
            text(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=:table_name"
            ),
            {"table_name": table_name},
        )
        ans = cursor.fetchall()
        return ans[0][0]

    def get_fields(self, table_name):
        """Return column descriptions for ``table_name``.

        Each entry is ``(name, type, notnull, dflt_value, pk)``, i.e. columns
        1-5 of the corresponding ``PRAGMA table_info`` row.
        """
        quoted = self._quote_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA table_info({quoted})"))
        return [
            (field[1], field[2], field[3], field[4], field[5])
            for field in cursor.fetchall()
        ]

    def get_users(self):
        """Return an empty list; no users are reported for SQLite."""
        return []

    def get_grants(self):
        """Return an empty list; no grants are reported for SQLite."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get charset."""
        return "UTF-8"

    def get_database_list(self):
        """Return an empty list; no database list is reported for SQLite."""
        return []

    def get_database_names(self):
        """Return an empty list; no database names are reported for SQLite."""
        return []

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Reload the cached table/view names and re-reflect the metadata.

        Returns:
            The refreshed ``self._all_tables`` set (tables and views).
        """
        # Raw strings must be wrapped in text(): SQLAlchemy 2.x rejects plain
        # str statements and 1.4 only accepts them with a deprecation warning.
        rows = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type IN ('table', 'view')")
        )
        self._all_tables = {row[0] for row in rows}
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def _write(self, session, write_sql):
        """Execute a write statement, commit it and return the affected row count."""
        print(f"Write[{write_sql}]")
        result = session.execute(text(write_sql))
        session.commit()
        # TODO Subsequent optimization of dynamically specified database submission loss target problem
        print(f"SQL[{write_sql}], result:{result.rowcount}")
        return result.rowcount

    def get_table_comments(self, db_name=None):
        """Return ``(table name, CREATE statement)`` pairs for every table.

        ``db_name`` is accepted but not used by this implementation.
        """
        cursor = self.session.execute(
            text("SELECT name, sql FROM sqlite_master WHERE type='table'")
        )
        return [(row[0], row[1]) for row in cursor.fetchall()]

    def table_simple_info(self) -> Iterable[str]:
        """Return one ``"table(col1,col2,...);"`` summary string per table."""
        cursor = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type='table'")
        )
        results = []
        for row in cursor.fetchall():
            table_name = row[0]
            quoted = self._quote_identifier(table_name)
            columns = self.session.execute(
                text(f"PRAGMA table_info({quoted})")
            ).fetchall()
            # Column 1 of each table_info row is the column name.
            column_names = [column[1] for column in columns]
            results.append(f"{table_name}({','.join(column_names)});")
        return results
|
0
pilot/connections/rdbms/tests/__init__.py
Normal file
0
pilot/connections/rdbms/tests/__init__.py
Normal file
123
pilot/connections/rdbms/tests/test_conn_sqlite.py
Normal file
123
pilot/connections/rdbms/tests/test_conn_sqlite.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
Run unit test with command: pytest pilot/connections/rdbms/tests/test_conn_sqlite.py
|
||||
"""
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a SQLiteConnect bound to a fresh temporary file, removed afterwards."""
    # delete=False keeps the file on disk after the handle is closed, so the
    # connection can open it by name.
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        db_path = handle.name
    connection = SQLiteConnect.from_file_path(db_path)
    yield connection
    os.unlink(db_path)
|
||||
|
||||
|
||||
def test_get_table_names(db):
    """An empty database reports no table names."""
    names = list(db.get_table_names())
    assert names == []
|
||||
|
||||
|
||||
def test_get_table_info(db):
    """An empty database yields an empty table-info string."""
    info = db.get_table_info()
    assert info == ""
|
||||
|
||||
|
||||
def test_get_table_info_with_table(db):
    """After creating a table and resyncing, its DDL shows up in the table info."""
    db.run(db.session, "CREATE TABLE test (id INTEGER);")
    synced = db._sync_tables_from_db()
    print(synced)
    info = db.get_table_info()
    assert "CREATE TABLE test" in info
|
||||
|
||||
|
||||
def test_run_sql(db):
    """Running DDL returns rows whose first entry is the table_info header."""
    expected_header = ("cid", "name", "type", "notnull", "dflt_value", "pk")
    rows = db.run(db.session, "CREATE TABLE test (id INTEGER);")
    assert rows[0] == expected_header
|
||||
|
||||
|
||||
def test_run_no_throw(db):
    """Invalid SQL is reported as an "Error:" string instead of raising."""
    outcome = db.run_no_throw(db.session, "this is a error sql")
    assert outcome.startswith("Error:")
|
||||
|
||||
|
||||
def test_get_indexes(db):
    """An explicitly created index is listed with origin "c"."""
    for statement in (
        "CREATE TABLE test (name TEXT);",
        "CREATE INDEX idx_name ON test(name);",
    ):
        db.run(db.session, statement)
    indexes = db.get_indexes("test")
    assert indexes == [("idx_name", "c")]
|
||||
|
||||
|
||||
def test_get_indexes_empty(db):
    """A table with only an INTEGER PRIMARY KEY reports no indexes."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    indexes = db.get_indexes("test")
    assert indexes == []
|
||||
|
||||
|
||||
def test_get_show_create_table(db):
    """The stored CREATE statement is returned without the trailing semicolon."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    ddl = db.get_show_create_table("test")
    assert ddl == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
|
||||
|
||||
|
||||
def test_get_fields(db):
    """Fields are reported as (name, type, notnull, dflt_value, pk)."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    fields = db.get_fields("test")
    assert fields == [("id", "INTEGER", 0, None, 1)]
|
||||
|
||||
|
||||
def test_get_charset(db):
    """The charset is reported as UTF-8."""
    charset = db.get_charset()
    assert charset == "UTF-8"
|
||||
|
||||
|
||||
def test_get_collation(db):
    """The collation is reported as UTF-8."""
    collation = db.get_collation()
    assert collation == "UTF-8"
|
||||
|
||||
|
||||
def test_table_simple_info(db):
    """Each table is summarised as "name(columns);"."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    summary = db.table_simple_info()
    assert summary == ["test(id);"]
|
||||
|
||||
|
||||
def test_get_table_info_no_throw(db):
    """Asking for an unknown table yields an "Error:" string instead of raising."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    outcome = db.get_table_info_no_throw("xxxx_table")
    assert outcome.startswith("Error:")
|
||||
|
||||
|
||||
def test_query_ex(db):
    """query_ex returns (field names, rows) for both fetch="all" and fetch="one"."""
    for statement in (
        "CREATE TABLE test (id INTEGER PRIMARY KEY);",
        "insert into test(id) values (1)",
        "insert into test(id) values (2)",
    ):
        db.run(db.session, statement)

    # Default fetch mode returns every row.
    columns, rows = db.query_ex(db.session, "select * from test")
    assert columns == ["id"]
    assert rows == [(1,), (2,)]

    # fetch="one" still returns a list, holding only the first row.
    columns, rows = db.query_ex(db.session, "select * from test", fetch="one")
    assert columns == ["id"]
    assert rows == [(1,)]
|
||||
|
||||
|
||||
def test_convert_sql_write_to_select(db):
    """Placeholder: no assertions yet for convert_sql_write_to_select."""
    # TODO
    return None
|
||||
|
||||
|
||||
def test_get_grants(db):
    """No grants are reported."""
    grants = db.get_grants()
    assert grants == []
|
||||
|
||||
|
||||
def test_get_users(db):
    """No users are reported."""
    users = db.get_users()
    assert users == []
|
||||
|
||||
|
||||
def test_get_table_comments(db):
    """Table comments are (name, CREATE statement) pairs, empty before any table exists."""
    before = db.get_table_comments()
    assert before == []

    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    after = db.get_table_comments()
    expected = [("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")]
    assert after == expected
|
||||
|
||||
|
||||
def test_get_database_list(db):
    """No database list is reported."""
    # The comparison was previously a bare expression, so the test could
    # never fail; it must be asserted.
    assert db.get_database_list() == []
|
||||
|
||||
|
||||
def test_get_database_names(db):
    """No database names are reported."""
    # The comparison was previously a bare expression, so the test could
    # never fail; it must be asserted.
    assert db.get_database_names() == []
|
Loading…
Reference in New Issue
Block a user