feat: Support SQLite connection

This commit is contained in:
FangYin Cheng 2023-08-11 05:36:00 +08:00
parent 22658e36cf
commit 0859f36a89
10 changed files with 317 additions and 23 deletions

View File

@ -54,11 +54,17 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
#*******************************************************************#
#** DATABASE SETTINGS **#
#*******************************************************************#
### MySQL database (current default database)
LOCAL_DB_TYPE=mysql
LOCAL_DB_USER=root
LOCAL_DB_PASSWORD=aa12345678
LOCAL_DB_HOST=127.0.0.1
LOCAL_DB_PORT=3306
### SQLite database (TODO: SQLite database will become the default database configuration when it is stable.)
# LOCAL_DB_PATH=data/default_sqlite.db
# LOCAL_DB_TYPE=sqlite
### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)

1
.gitignore vendored
View File

@ -145,6 +145,7 @@ nltk_data
.vectordb
pilot/data/
pilot/nltk_data
pilot/mock_datas/db-gpt-test.db.wal
logswebserver.log.*
.history/*

View File

@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
FROM ${BASE_IMAGE}
ARG BASE_IMAGE
RUN apt-get update && apt-get install -y git python3 pip wget \
RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
&& apt-get clean
ARG BUILD_LOCAL_CODE="false"

View File

@ -1,5 +1,6 @@
from enum import auto, Enum
from typing import List, Any
import os
class SeparatorStyle(Enum):
@ -24,6 +25,7 @@ class DBType(Enum):
Mysql = DbInfo("mysql")
OCeanBase = DbInfo("oceanbase")
DuckDb = DbInfo("duckdb", True)
SQLite = DbInfo("sqlite", True)
Oracle = DbInfo("oracle")
MSSQL = DbInfo("mssql")
Postgresql = DbInfo("postgresql")
@ -40,3 +42,12 @@ class DBType(Enum):
if item.value() == db_type:
return item
return None
@staticmethod
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
    """Derive a database name for a file-based database from its path.

    The directory part of ``local_db_path`` is discarded and the file
    extension is removed. If the remaining name still contains a dot
    (for example ``name.db.wal``), one more extension is removed.

    Args:
        db_type: Database type string used as the name prefix.
        local_db_path: Path of the database file.

    Returns:
        ``db_type`` and the stripped file name joined by an underscore.
    """
    stem, _ = os.path.splitext(os.path.basename(local_db_path))
    if "." in stem:
        # Double extension: drop one more suffix.
        stem, _ = os.path.splitext(stem)
    return "_".join((db_type, stem))

View File

@ -121,8 +121,12 @@ class Config(metaclass=Singleton):
)
### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
self.LOCAL_DB_HOST = "127.0.0.1"
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")

View File

@ -8,6 +8,7 @@ from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.singleton import Singleton
@ -89,12 +90,21 @@ class ConnectManager:
"",
)
if CFG.LOCAL_DB_PATH:
# default file db is duckdb
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
db_name = CFG.LOCAL_DB_NAME
db_type = CFG.LOCAL_DB_TYPE
db_path = CFG.LOCAL_DB_PATH
if not db_name:
if db_type is None or db_type == DBType.DuckDb.value():
# file db is duckdb
db_name = self.storage.get_file_db_name(db_path)
db_type = DBType.DuckDb.value()
else:
db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
if db_name:
self.storage.add_file_db(
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
print(
f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
)
self.storage.add_file_db(db_name, db_type, db_path)
def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name)

View File

@ -21,6 +21,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
from pilot.common.schema import DBType
from pilot.connections.base import BaseConnect
from pilot.configs.config import Config
@ -78,16 +79,7 @@ class RDBMSDatabase(BaseConnect):
self._metadata = MetaData()
self._metadata.reflect(bind=self._engine)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=self._engine.url.database)
+ (
self._inspector.get_view_names(schema=self._engine.url.database)
if self.view_support
else []
)
)
self._all_tables = self._sync_tables_from_db()
@classmethod
def from_uri_db(
@ -128,6 +120,26 @@ class RDBMSDatabase(BaseConnect):
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def _sync_tables_from_db(self) -> Iterable[str]:
    """Reload the set of table names from the database.

    Table names are read through the inspector; view names are added as
    well when ``view_support`` is true. The result is stored on
    ``self._all_tables`` and also returned.

    Returns:
        The refreshed set of table (and optionally view) names.
    """
    # TODO Use a background thread to refresh periodically
    # SQLite is inspected without a schema; per the original author's note,
    # supplying one raises an error. Every other database type uses the
    # database name from the engine URL.
    if self.db_type == DBType.SQLite.value():
        schema_name = None
    else:
        schema_name = self._engine.url.database

    names = list(self._inspector.get_table_names(schema=schema_name))
    if self.view_support:
        # Views are listed next to tables only when view support is enabled.
        names = names + self._inspector.get_view_names(schema=schema_name)

    self._all_tables = set(names)
    return self._all_tables
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
@ -250,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def __write(self, session, write_sql):
def _write(self, session, write_sql):
print(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = session.execute(text(write_sql))
@ -279,7 +291,7 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
result = result = [cursor.fetchone()] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = tuple(i[0:] for i in cursor.keys())
@ -305,11 +317,10 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
result = [cursor.fetchone()] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
result = list(result)
return field_names, result
@ -323,7 +334,7 @@ class RDBMSDatabase(BaseConnect):
if sql_type == "SELECT":
return self.__query(session, command, fetch)
else:
self.__write(session, command)
self._write(session, command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(session, select_sql)
@ -332,6 +343,11 @@ class RDBMSDatabase(BaseConnect):
print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command))
session.commit()
_show_columns_sql = (
f"PRAGMA table_info({table_name})"
if self.db_type == "sqlite"
else f"SHOW COLUMNS FROM {table_name}"
)
if cursor.returns_rows:
result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys())
@ -339,10 +355,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
return self.__query(session, _show_columns_sql)
return result
else:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
return self.__query(session, _show_columns_sql)
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any, Iterable
from sqlalchemy import create_engine, text
from pilot.connections.rdbms.base import RDBMSDatabase
class SQLiteConnect(RDBMSDatabase):
    """Connection to a file-based SQLite database.

    Metadata is read from ``sqlite_master`` and ``PRAGMA`` statements.
    Users, grants and database lists are always reported as empty lists.

    Usage:
        conn = SQLiteConnect.from_file_path("data/default_sqlite.db")
    """

    db_type: str = "sqlite"
    db_dialect: str = "sqlite"

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Create a connection for the SQLite database file at ``file_path``.

        Args:
            file_path: Path of the SQLite database file.
            engine_args: Optional keyword arguments forwarded to
                ``sqlalchemy.create_engine``.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of this class bound to the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted identifier, doubling embedded quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    def get_indexes(self, table_name):
        """Get the indexes of the specified table.

        Returns:
            A list of tuples built from columns 1 and 3 of each
            ``PRAGMA index_list`` row (the index name and its origin flag).
        """
        # PRAGMA arguments cannot be bound, so the name is quoted instead.
        quoted = self._quote_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA index_list({quoted})"))
        indexes = cursor.fetchall()
        return [(index[1], index[3]) for index in indexes]

    def get_show_create_table(self, table_name):
        """Get the ``CREATE TABLE`` statement stored for the specified table.

        Raises:
            IndexError: If ``sqlite_master`` has no table with that name.
        """
        # Bind the name as a parameter instead of formatting it into the SQL.
        cursor = self.session.execute(
            text(
                "SELECT sql FROM sqlite_master "
                "WHERE type='table' AND name=:table_name"
            ),
            {"table_name": table_name},
        )
        ans = cursor.fetchall()
        return ans[0][0]

    def get_fields(self, table_name):
        """Get the column fields of the specified table.

        Returns:
            A list of ``(name, type, notnull, dflt_value, pk)`` tuples, one
            per row of ``PRAGMA table_info``.
        """
        quoted = self._quote_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA table_info({quoted})"))
        fields = cursor.fetchall()
        return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]

    def get_users(self):
        """Return an empty list of users."""
        return []

    def get_grants(self):
        """Return an empty list of grants."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get charset."""
        return "UTF-8"

    def get_database_list(self):
        """Return an empty database list."""
        return []

    def get_database_names(self):
        """Return an empty list of database names."""
        return []

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Reload table and view names from ``sqlite_master``.

        The names are stored on ``self._all_tables``, the SQLAlchemy metadata
        is reflected again, and the name set is returned.
        """
        # Raw strings are wrapped in text(): SQLAlchemy 2.x rejects plain
        # string statements, and the rest of this class already uses text().
        rows = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type IN ('table', 'view')")
        )
        self._all_tables = {row[0] for row in rows}
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def _write(self, session, write_sql):
        """Execute a write statement, commit, and return the affected row count."""
        print(f"Write[{write_sql}]")
        result = session.execute(text(write_sql))
        session.commit()
        # TODO Subsequent optimization of dynamically specified database submission loss target problem
        print(f"SQL[{write_sql}], result:{result.rowcount}")
        return result.rowcount

    def get_table_comments(self, db_name=None):
        """Return ``(table name, CREATE statement)`` pairs for every table.

        ``db_name`` is accepted but not used by this implementation.
        """
        cursor = self.session.execute(
            text("SELECT name, sql FROM sqlite_master WHERE type='table'")
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]

    def table_simple_info(self) -> Iterable[str]:
        """Return one ``table(col1,col2,...);`` summary string per table."""
        cursor = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type='table'")
        )
        tables_results = cursor.fetchall()
        results = []
        for row in tables_results:
            table_name = row[0]
            # Quote the name so tables with spaces or keyword names still work.
            quoted = self._quote_identifier(table_name)
            cursor_colums = self.session.execute(
                text(f"PRAGMA table_info({quoted})")
            )
            # Column 1 of each PRAGMA table_info row is the column name.
            table_colums = [row_col[1] for row_col in cursor_colums.fetchall()]
            results.append(f"{table_name}({','.join(table_colums)});")
        return results

View File

@ -0,0 +1,123 @@
"""
Run unit test with command: pytest pilot/connections/rdbms/tests/test_conn_sqlite.py
"""
import pytest
import tempfile
import os
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
@pytest.fixture
def db():
    """Yield a SQLiteConnect bound to a fresh temporary database file.

    The file is created empty and closed before SQLiteConnect opens it by
    path; it is removed once the test using the fixture has finished.
    """
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        db_path = handle.name
    connection = SQLiteConnect.from_file_path(db_path)
    yield connection
    os.unlink(db_path)
def test_get_table_names(db):
    """A freshly created database reports no table names."""
    table_names = list(db.get_table_names())
    assert table_names == []
def test_get_table_info(db):
    """Table info for an empty database is the empty string."""
    table_info = db.get_table_info()
    assert table_info == ""
def test_get_table_info_with_table(db):
    """After creating a table and resyncing, table info contains its DDL."""
    create_sql = "CREATE TABLE test (id INTEGER);"
    db.run(db.session, create_sql)
    # Refresh the cached table names so the new table is visible.
    synced = db._sync_tables_from_db()
    print(synced)
    assert "CREATE TABLE test" in db.get_table_info()
def test_run_sql(db):
    """Running DDL returns rows whose first entry is the table_info header."""
    expected_header = ("cid", "name", "type", "notnull", "dflt_value", "pk")
    rows = db.run(db.session, "CREATE TABLE test (id INTEGER);")
    assert rows[0] == expected_header
def test_run_no_throw(db):
    """Invalid SQL yields an "Error:" message instead of an exception."""
    outcome = db.run_no_throw(db.session, "this is a error sql")
    assert outcome.startswith("Error:")
def test_get_indexes(db):
    """An explicitly created index is reported as ("idx_name", "c")."""
    for statement in (
        "CREATE TABLE test (name TEXT);",
        "CREATE INDEX idx_name ON test(name);",
    ):
        db.run(db.session, statement)
    indexes = db.get_indexes("test")
    assert indexes == [("idx_name", "c")]
def test_get_indexes_empty(db):
    """A table with only an INTEGER PRIMARY KEY column reports no indexes."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    indexes = db.get_indexes("test")
    assert indexes == []
def test_get_show_create_table(db):
    """The stored DDL is returned without the trailing semicolon."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    ddl = db.get_show_create_table("test")
    assert ddl == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
def test_get_fields(db):
    """get_fields returns one five-element tuple per column."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    expected_fields = [("id", "INTEGER", 0, None, 1)]
    assert db.get_fields("test") == expected_fields
def test_get_charset(db):
    """The reported charset is "UTF-8"."""
    charset = db.get_charset()
    assert charset == "UTF-8"
def test_get_collation(db):
    """The reported collation is "UTF-8"."""
    collation = db.get_collation()
    assert collation == "UTF-8"
def test_table_simple_info(db):
    """Each table is summarised as "name(col,...);"."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    summaries = db.table_simple_info()
    assert summaries == ["test(id);"]
def test_get_table_info_no_throw(db):
    """Asking for an unknown table yields an "Error:" message."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    message = db.get_table_info_no_throw("xxxx_table")
    assert message.startswith("Error:")
def test_query_ex(db):
    """query_ex returns column names and rows for both fetch modes."""
    setup_statements = (
        "CREATE TABLE test (id INTEGER PRIMARY KEY);",
        "insert into test(id) values (1)",
        "insert into test(id) values (2)",
    )
    for statement in setup_statements:
        db.run(db.session, statement)

    # Default fetch mode returns every row.
    columns, rows = db.query_ex(db.session, "select * from test")
    assert columns == ["id"]
    assert rows == [(1,), (2,)]

    # fetch="one" still returns a list, holding only the first row.
    columns, rows = db.query_ex(db.session, "select * from test", fetch="one")
    assert columns == ["id"]
    assert rows == [(1,)]
def test_convert_sql_write_to_select(db):
    """Placeholder: convert_sql_write_to_select is not covered yet."""
    # TODO
    return None
def test_get_grants(db):
    """No grants are reported."""
    grants = db.get_grants()
    assert grants == []
def test_get_users(db):
    """No users are reported."""
    users = db.get_users()
    assert users == []
def test_get_table_comments(db):
    """Table comments are (name, DDL) pairs; none exist before a table is created."""
    comments_before = db.get_table_comments()
    assert comments_before == []

    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    expected = [("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")]
    assert db.get_table_comments() == expected
def test_get_database_list(db):
    """The database list is empty."""
    # The original compared the values without `assert`, so it could never fail.
    assert db.get_database_list() == []
def test_get_database_names(db):
    """The list of database names is empty."""
    # The original compared the values without `assert`, so it could never fail.
    assert db.get_database_names() == []