diff --git a/.env.template b/.env.template
index f1418323d..2058c36a3 100644
--- a/.env.template
+++ b/.env.template
@@ -54,11 +54,17 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
 
 #*******************************************************************#
 #**                      DATABASE SETTINGS                        **#
 #*******************************************************************#
+### MYSQL database(Current default database)
+LOCAL_DB_TYPE=mysql
 LOCAL_DB_USER=root
 LOCAL_DB_PASSWORD=aa12345678
 LOCAL_DB_HOST=127.0.0.1
 LOCAL_DB_PORT=3306
 
+### SQLite database (TODO: SQLite database will become the default database configuration when it is stable.)
+# LOCAL_DB_PATH=data/default_sqlite.db
+# LOCAL_DB_TYPE=sqlite
+
 ### MILVUS
 ## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
diff --git a/.gitignore b/.gitignore
index 2e62e9841..55e18c91f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,6 +145,7 @@ nltk_data
 .vectordb
 pilot/data/
 pilot/nltk_data
+pilot/mock_datas/db-gpt-test.db.wal
 logs
 webserver.log.*
 .history/*
diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile
index 5144c1348..d65fc2a98 100644
--- a/docker/base/Dockerfile
+++ b/docker/base/Dockerfile
@@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
 FROM ${BASE_IMAGE}
 ARG BASE_IMAGE
 
-RUN apt-get update && apt-get install -y git python3 pip wget \
+RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
     && apt-get clean
 
 ARG BUILD_LOCAL_CODE="false"
diff --git a/pilot/common/schema.py b/pilot/common/schema.py
index de035d687..eb25fe5e9 100644
--- a/pilot/common/schema.py
+++ b/pilot/common/schema.py
@@ -1,5 +1,6 @@
 from enum import auto, Enum
 from typing import List, Any
+import os
 
 
 class SeparatorStyle(Enum):
@@ -24,6 +25,7 @@ class DBType(Enum):
     Mysql = DbInfo("mysql")
     OCeanBase = DbInfo("oceanbase")
     DuckDb = DbInfo("duckdb", True)
+    SQLite = DbInfo("sqlite", True)
     Oracle = DbInfo("oracle")
     MSSQL = DbInfo("mssql")
     Postgresql = DbInfo("postgresql")
@@ -40,3 +42,12 @@ class DBType(Enum):
             if item.value() == db_type:
                 return item
         return None
+
+    @staticmethod
+    def parse_file_db_name_from_path(db_type: str, local_db_path: str):
+        """Parse out the database name of the embedded database from the file path"""
+        base_name = os.path.basename(local_db_path)
+        db_name = os.path.splitext(base_name)[0]
+        if "." in db_name:
+            db_name = os.path.splitext(db_name)[0]
+        return db_type + "_" + db_name
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 5846a7c1a..609d946c2 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -121,8 +121,12 @@ class Config(metaclass=Singleton):
         )
 
         ### default Local database connection configuration
-        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
+        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
         self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
+        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE")
+        if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
+            self.LOCAL_DB_HOST = "127.0.0.1"
+        self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
         self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
         self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py
index 38d7dd9f0..19b9e64cf 100644
--- a/pilot/connections/manages/connection_manager.py
+++ b/pilot/connections/manages/connection_manager.py
@@ -8,6 +8,7 @@ from pilot.connections.base import BaseConnect
 
 from pilot.connections.rdbms.conn_mysql import MySQLConnect
 from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
+from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
 from pilot.connections.rdbms.conn_mssql import MSSQLConnect
 from pilot.connections.rdbms.base import RDBMSDatabase
 from pilot.singleton import Singleton
@@ -89,12 +90,21 @@ class ConnectManager:
                 "",
             )
         if CFG.LOCAL_DB_PATH:
-            # default file db is duckdb
-            db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
+            db_name = CFG.LOCAL_DB_NAME
+            db_type = CFG.LOCAL_DB_TYPE
+            db_path = CFG.LOCAL_DB_PATH
+            if not db_name:
+                if db_type is None or db_type == DBType.DuckDb.value():
+                    # file db is duckdb
+                    db_name = self.storage.get_file_db_name(db_path)
+                    db_type = DBType.DuckDb.value()
+                else:
+                    db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
             if db_name:
-                self.storage.add_file_db(
-                    db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
+                print(
+                    f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
                 )
+                self.storage.add_file_db(db_name, db_type, db_path)
 
     def get_connect(self, db_name):
         db_config = self.storage.get_db_config(db_name)
diff --git a/pilot/connections/rdbms/base.py b/pilot/connections/rdbms/base.py
index d2a853589..e1f96f155 100644
--- a/pilot/connections/rdbms/base.py
+++ b/pilot/connections/rdbms/base.py
@@ -21,6 +21,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
 from sqlalchemy.schema import CreateTable
 from sqlalchemy.orm import sessionmaker, scoped_session
 
+from pilot.common.schema import DBType
 from pilot.connections.base import BaseConnect
 from pilot.configs.config import Config
 
@@ -78,16 +79,7 @@ class RDBMSDatabase(BaseConnect):
         self._metadata = MetaData()
         self._metadata.reflect(bind=self._engine)
 
-        # including view support by adding the views as well as tables to the all
-        # tables list if view_support is True
-        self._all_tables = set(
-            self._inspector.get_table_names(schema=self._engine.url.database)
-            + (
-                self._inspector.get_view_names(schema=self._engine.url.database)
-                if self.view_support
-                else []
-            )
-        )
+        self._all_tables = self._sync_tables_from_db()
 
     @classmethod
     def from_uri_db(
@@ -128,6 +120,26 @@ class RDBMSDatabase(BaseConnect):
         """Return string representation of dialect to use."""
         return self._engine.dialect.name
 
+    def _sync_tables_from_db(self) -> Iterable[str]:
+        """Read table information from database"""
database""" + # TODO Use a background thread to refresh periodically + + # SQL will raise error with schema + _schema = ( + None if self.db_type == DBType.SQLite.value() else self._engine.url.database + ) + # including view support by adding the views as well as tables to the all + # tables list if view_support is True + self._all_tables = set( + self._inspector.get_table_names(schema=_schema) + + ( + self._inspector.get_view_names(schema=_schema) + if self.view_support + else [] + ) + ) + return self._all_tables + def get_usable_table_names(self) -> Iterable[str]: """Get names of tables available.""" if self._include_tables: @@ -250,7 +262,7 @@ class RDBMSDatabase(BaseConnect): """Format the error message""" return f"Error: {e}" - def __write(self, session, write_sql): + def _write(self, session, write_sql): print(f"Write[{write_sql}]") db_cache = self._engine.url.database result = session.execute(text(write_sql)) @@ -279,7 +291,7 @@ class RDBMSDatabase(BaseConnect): if fetch == "all": result = cursor.fetchall() elif fetch == "one": - result = cursor.fetchone()[0] # type: ignore + result = result = [cursor.fetchone()] # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") field_names = tuple(i[0:] for i in cursor.keys()) @@ -305,11 +317,10 @@ class RDBMSDatabase(BaseConnect): if fetch == "all": result = cursor.fetchall() elif fetch == "one": - result = cursor.fetchone()[0] # type: ignore + result = [cursor.fetchone()] # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") field_names = list(i[0:] for i in cursor.keys()) - result = list(result) return field_names, result @@ -323,7 +334,7 @@ class RDBMSDatabase(BaseConnect): if sql_type == "SELECT": return self.__query(session, command, fetch) else: - self.__write(session, command) + self._write(session, command) select_sql = self.convert_sql_write_to_select(command) print(f"write result query:{select_sql}") return self.__query(session, select_sql) @@ -332,6 +343,11 @@ class RDBMSDatabase(BaseConnect): print(f"DDL execution determines whether to enable through configuration ") cursor = session.execute(text(command)) session.commit() + _show_columns_sql = ( + f"PRAGMA table_info({table_name})" + if self.db_type == "sqlite" + else f"SHOW COLUMNS FROM {table_name}" + ) if cursor.returns_rows: result = cursor.fetchall() field_names = tuple(i[0:] for i in cursor.keys()) @@ -339,10 +355,10 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query(session, f"SHOW COLUMNS FROM {table_name}") + return self.__query(session, _show_columns_sql) return result else: - return self.__query(session, f"SHOW COLUMNS FROM {table_name}") + return self.__query(session, _show_columns_sql) def run_no_throw(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results. 
diff --git a/pilot/connections/rdbms/conn_sqlite.py b/pilot/connections/rdbms/conn_sqlite.py
new file mode 100644
index 000000000..f4c8084a2
--- /dev/null
+++ b/pilot/connections/rdbms/conn_sqlite.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import Optional, Any, Iterable
+from sqlalchemy import create_engine, text
+
+from pilot.connections.rdbms.base import RDBMSDatabase
+
+
+class SQLiteConnect(RDBMSDatabase):
+    """Connect to a SQLite database and fetch its metadata.
+    Args:
+    Usage:
+    """
+
+    db_type: str = "sqlite"
+    db_dialect: str = "sqlite"
+
+    @classmethod
+    def from_file_path(
+        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
+        """Construct a SQLAlchemy engine from URI."""
+        _engine_args = engine_args or {}
+        return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)
+
+    def get_indexes(self, table_name):
+        """Get table indexes about specified table."""
+        cursor = self.session.execute(text(f"PRAGMA index_list({table_name})"))
+        indexes = cursor.fetchall()
+        return [(index[1], index[3]) for index in indexes]
+
+    def get_show_create_table(self, table_name):
+        """Get table show create table about specified table."""
+        cursor = self.session.execute(
+            text(
+                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'"
+            )
+        )
+        ans = cursor.fetchall()
+        return ans[0][0]
+
+    def get_fields(self, table_name):
+        """Get column fields about specified table."""
+        cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
+        fields = cursor.fetchall()
+        print(fields)
+        return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
+
+    def get_users(self):
+        return []
+
+    def get_grants(self):
+        return []
+
+    def get_collation(self):
+        """Get collation."""
+        return "UTF-8"
+
+    def get_charset(self):
+        return "UTF-8"
+
+    def get_database_list(self):
+        return []
+
+    def get_database_names(self):
+        return []
+
+    def _sync_tables_from_db(self) -> Iterable[str]:
+        table_results = self.session.execute(
+            "SELECT name FROM sqlite_master WHERE type='table'"
+        )
+        view_results = self.session.execute(
+            "SELECT name FROM sqlite_master WHERE type='view'"
+        )
+        table_results = set(row[0] for row in table_results)
+        view_results = set(row[0] for row in view_results)
+        self._all_tables = table_results.union(view_results)
+        self._metadata.reflect(bind=self._engine)
+        return self._all_tables
+
+    def _write(self, session, write_sql):
+        print(f"Write[{write_sql}]")
+        result = session.execute(text(write_sql))
+        session.commit()
+        # TODO: follow-up optimization for losing the commit target when the database is specified dynamically
+        print(f"SQL[{write_sql}], result:{result.rowcount}")
+        return result.rowcount
+
+    def get_table_comments(self, db_name=None):
+        cursor = self.session.execute(
+            text(
+                f"""
+                SELECT name, sql FROM sqlite_master WHERE type='table'
+                """
+            )
+        )
+        table_comments = cursor.fetchall()
+        return [
+            (table_comment[0], table_comment[1]) for table_comment in table_comments
+        ]
+
+    def table_simple_info(self) -> Iterable[str]:
+        _tables_sql = f"""
+                SELECT name FROM sqlite_master WHERE type='table'
+            """
+        cursor = self.session.execute(text(_tables_sql))
+        tables_results = cursor.fetchall()
+        results = []
+        for row in tables_results:
+            table_name = row[0]
+            _sql = f"""
+                PRAGMA table_info({table_name})
+            """
+            cursor_colums = self.session.execute(text(_sql))
+            colum_results = cursor_colums.fetchall()
+            table_colums = []
+            for row_col in colum_results:
+                field_info = list(row_col)
+                table_colums.append(field_info[1])
+
+            results.append(f"{table_name}({','.join(table_colums)});")
+        return results
diff --git a/pilot/connections/rdbms/tests/__init__.py b/pilot/connections/rdbms/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/connections/rdbms/tests/test_conn_sqlite.py b/pilot/connections/rdbms/tests/test_conn_sqlite.py
new file mode 100644
index 000000000..efe4ddf76
--- /dev/null
+++ b/pilot/connections/rdbms/tests/test_conn_sqlite.py
@@ -0,0 +1,123 @@
+"""
+Run unit test with command: pytest pilot/connections/rdbms/tests/test_conn_sqlite.py
+"""
+import pytest
+import tempfile
+import os
+from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
+
+
+@pytest.fixture
+def db():
+    temp_db_file = tempfile.NamedTemporaryFile(delete=False)
+    temp_db_file.close()
+    conn = SQLiteConnect.from_file_path(temp_db_file.name)
+    yield conn
+    os.unlink(temp_db_file.name)
+
+
+def test_get_table_names(db):
+    assert list(db.get_table_names()) == []
+
+
+def test_get_table_info(db):
+    assert db.get_table_info() == ""
+
+
+def test_get_table_info_with_table(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER);")
+    print(db._sync_tables_from_db())
+    table_info = db.get_table_info()
+    assert "CREATE TABLE test" in table_info
+
+
+def test_run_sql(db):
+    result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
+    assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
+
+
+def test_run_no_throw(db):
+    assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
+
+
+def test_get_indexes(db):
+    db.run(db.session, "CREATE TABLE test (name TEXT);")
+    db.run(db.session, "CREATE INDEX idx_name ON test(name);")
+    assert db.get_indexes("test") == [("idx_name", "c")]
+
+
+def test_get_indexes_empty(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert db.get_indexes("test") == []
+
+
+def test_get_show_create_table(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert (
+        db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
+    )
+
+
+def test_get_fields(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
+
+
+def test_get_charset(db):
+    assert db.get_charset() == "UTF-8"
+
+
+def test_get_collation(db):
+    assert db.get_collation() == "UTF-8"
+
+
+def test_table_simple_info(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert db.table_simple_info() == ["test(id);"]
+
+
+def test_get_table_info_no_throw(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
+
+
+def test_query_ex(db):
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run(db.session, "insert into test(id) values (1)")
+    db.run(db.session, "insert into test(id) values (2)")
+    field_names, result = db.query_ex(db.session, "select * from test")
+    assert field_names == ["id"]
+    assert result == [(1,), (2,)]
+
+    field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
+    assert field_names == ["id"]
+    assert result == [(1,)]
+
+
+def test_convert_sql_write_to_select(db):
+    # TODO
+    pass
+
+
+def test_get_grants(db):
+    assert db.get_grants() == []
+
+
+def test_get_users(db):
+    assert db.get_users() == []
+
+
+def test_get_table_comments(db):
+    assert db.get_table_comments() == []
+    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    assert db.get_table_comments() == [
+        ("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
+    ]
+
+
+def test_get_database_list(db):
+    assert db.get_database_list() == []
+
+
+def test_get_database_names(db):
+    assert db.get_database_names() == []
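
Usage sketch (not part of the patch): a minimal example of how the pieces added above fit together, following the same calls exercised in test_conn_sqlite.py; the temporary file and table name are illustrative only.

import os
import tempfile

from pilot.common.schema import DBType
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect

# A registered name for a file-based database is derived from its path,
# e.g. "data/default_sqlite.db" -> "sqlite_default_sqlite".
print(DBType.parse_file_db_name_from_path("sqlite", "data/default_sqlite.db"))

# Build a connector directly from a database file and run DDL/DML through it,
# mirroring the pytest fixture above.
tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
tmp.close()
db = SQLiteConnect.from_file_path(tmp.name)
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")

field_names, rows = db.query_ex(db.session, "select * from test")
print(field_names, rows)       # ['id'] [(1,)]
print(db.table_simple_info())  # ['test(id);']
print(db.get_fields("test"))   # column metadata via PRAGMA table_info

os.unlink(tmp.name)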