feat: Support SQLite connection (#438)

Close #423 

1. New `SQLiteConnect`.
2. Knowledge QA support SQLite .
3. Build docker image with SQLite.
This commit is contained in:
Aries-ckt 2023-08-12 12:03:13 +08:00 committed by GitHub
commit 1ba9c51c91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 655 additions and 92 deletions

View File

@ -54,11 +54,16 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
#*******************************************************************# #*******************************************************************#
#** DATABASE SETTINGS **# #** DATABASE SETTINGS **#
#*******************************************************************# #*******************************************************************#
LOCAL_DB_USER=root ### SQLite database (Current default database)
LOCAL_DB_PASSWORD=aa12345678 LOCAL_DB_PATH=data/default_sqlite.db
LOCAL_DB_HOST=127.0.0.1 LOCAL_DB_TYPE=sqlite
LOCAL_DB_PORT=3306
### MYSQL database
# LOCAL_DB_TYPE=mysql
# LOCAL_DB_USER=root
# LOCAL_DB_PASSWORD=aa12345678
# LOCAL_DB_HOST=127.0.0.1
# LOCAL_DB_PORT=3306
### MILVUS ### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530) ## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)

1
.gitignore vendored
View File

@ -145,6 +145,7 @@ nltk_data
.vectordb .vectordb
pilot/data/ pilot/data/
pilot/nltk_data pilot/nltk_data
pilot/mock_datas/db-gpt-test.db.wal
logswebserver.log.* logswebserver.log.*
.history/* .history/*

View File

@ -18,13 +18,13 @@ services:
networks: networks:
- dbgptnet - dbgptnet
webserver: webserver:
image: db-gpt:latest image: eosphorosai/dbgpt-allinone:latest
command: python3 pilot/server/dbgpt_server.py command: python3 pilot/server/dbgpt_server.py
environment: environment:
- LOCAL_DB_HOST=db - LOCAL_DB_HOST=db
- LOCAL_DB_PASSWORD=aa123456 - LOCAL_DB_PASSWORD=aa123456
- ALLOWLISTED_PLUGINS=db_dashboard - ALLOWLISTED_PLUGINS=db_dashboard
- LLM_MODEL=vicuna-13b - LLM_MODEL=vicuna-13b-v1.5
depends_on: depends_on:
- db - db
volumes: volumes:

View File

@ -1,4 +1,4 @@
ARG BASE_IMAGE="db-gpt:latest" ARG BASE_IMAGE="eosphorosai/dbgpt:latest"
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}
@ -25,6 +25,6 @@ ENV LOCAL_DB_PASSWORD="$MYSQL_ROOT_PASSWORD"
RUN cp /app/assets/schema/knowledge_management.sql /docker-entrypoint-initdb.d/ RUN cp /app/assets/schema/knowledge_management.sql /docker-entrypoint-initdb.d/
COPY docker/allinone/allinone-entrypoint.sh /usr/local/bin/allinone-entrypoint.sh COPY docker/allinone/allinone-entrypoint.sh /usr/local/bin/allinone-entrypoint.sh
COPY docker/examples/sqls/ /docker-entrypoint-initdb.d/ COPY docker/examples/sqls/*_mysql.sql /docker-entrypoint-initdb.d/
ENTRYPOINT ["/usr/local/bin/allinone-entrypoint.sh"] ENTRYPOINT ["/usr/local/bin/allinone-entrypoint.sh"]

View File

@ -4,6 +4,6 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")" cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd) WORK_DIR=$(pwd)
IMAGE_NAME="db-gpt-allinone" IMAGE_NAME="eosphorosai/dbgpt-allinone"
docker build -f Dockerfile -t $IMAGE_NAME $WORK_DIR/../../ docker build -f Dockerfile -t $IMAGE_NAME $WORK_DIR/../../

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
docker run --gpus "device=0" -d -p 3306:3306 \ docker run --gpus all -d -p 3306:3306 \
-p 5000:5000 \ -p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \ -e LOCAL_DB_HOST=127.0.0.1 \
-e LOCAL_DB_PASSWORD=aa123456 \ -e LOCAL_DB_PASSWORD=aa123456 \
@ -9,5 +9,5 @@ docker run --gpus "device=0" -d -p 3306:3306 \
-e LANGUAGE=zh \ -e LANGUAGE=zh \
-v /data:/data \ -v /data:/data \
-v /data/models:/app/models \ -v /data/models:/app/models \
--name db-gpt-allinone \ --name dbgpt-allinone \
db-gpt-allinone eosphorosai/dbgpt-allinone

View File

@ -4,7 +4,7 @@
PROXY_API_KEY="$PROXY_API_KEY" PROXY_API_KEY="$PROXY_API_KEY"
PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}" PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}"
docker run --gpus "device=0" -d -p 3306:3306 \ docker run --gpus all -d -p 3306:3306 \
-p 5000:5000 \ -p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \ -e LOCAL_DB_HOST=127.0.0.1 \
-e LOCAL_DB_PASSWORD=aa123456 \ -e LOCAL_DB_PASSWORD=aa123456 \
@ -15,5 +15,5 @@ docker run --gpus "device=0" -d -p 3306:3306 \
-e LANGUAGE=zh \ -e LANGUAGE=zh \
-v /data:/data \ -v /data:/data \
-v /data/models:/app/models \ -v /data/models:/app/models \
--name db-gpt-allinone \ --name dbgpt-allinone \
db-gpt-allinone eosphorosai/dbgpt-allinone

View File

@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}
ARG BASE_IMAGE ARG BASE_IMAGE
RUN apt-get update && apt-get install -y git python3 pip wget \ RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
&& apt-get clean && apt-get clean
ARG BUILD_LOCAL_CODE="false" ARG BUILD_LOCAL_CODE="false"
@ -45,4 +45,14 @@ RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
else rm -rf /tmp/app; \ else rm -rf /tmp/app; \
fi;) fi;)
EXPOSE 5000 ARG LOAD_EXAMPLES="true"
RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
then mkdir -p /app/pilot/data && sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/case_1_student_manager_sqlite.sql \
&& sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/case_2_ecom_sqlite.sql \
&& sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \
fi;)
EXPOSE 5000
CMD ["python3", "pilot/server/dbgpt_server.py"]

View File

@ -5,12 +5,13 @@ cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd) WORK_DIR=$(pwd)
BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04" BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
IMAGE_NAME="db-gpt" IMAGE_NAME="eosphorosai/dbgpt"
# zh: https://pypi.tuna.tsinghua.edu.cn/simple # zh: https://pypi.tuna.tsinghua.edu.cn/simple
PIP_INDEX_URL="https://pypi.org/simple" PIP_INDEX_URL="https://pypi.org/simple"
# en or zh # en or zh
LANGUAGE="en" LANGUAGE="en"
BUILD_LOCAL_CODE="false" BUILD_LOCAL_CODE="false"
LOAD_EXAMPLES="true"
usage () { usage () {
echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04] [--image-name db-gpt]" echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04] [--image-name db-gpt]"
@ -19,6 +20,7 @@ usage () {
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple" echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
echo " [--language en or zh] You language, default: en" echo " [--language en or zh] You language, default: en"
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false" echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
echo " [--load-examples true or false] Whether to load examples to default database default: true"
echo " [-h|--help] Usage message" echo " [-h|--help] Usage message"
} }
@ -50,6 +52,11 @@ while [[ $# -gt 0 ]]; do
shift shift
shift shift
;; ;;
--load-examples)
LOAD_EXAMPLES="$2"
shift
shift
;;
-h|--help) -h|--help)
help="true" help="true"
shift shift
@ -71,5 +78,6 @@ docker build \
--build-arg PIP_INDEX_URL=$PIP_INDEX_URL \ --build-arg PIP_INDEX_URL=$PIP_INDEX_URL \
--build-arg LANGUAGE=$LANGUAGE \ --build-arg LANGUAGE=$LANGUAGE \
--build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \ --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \
--build-arg LOAD_EXAMPLES=$LOAD_EXAMPLES \
-f Dockerfile \ -f Dockerfile \
-t $IMAGE_NAME $WORK_DIR/../../ -t $IMAGE_NAME $WORK_DIR/../../

12
docker/base/run_sqlite.sh Executable file
View File

@ -0,0 +1,12 @@
#!/bin/bash
# Run the DB-GPT all-in-one container backed by the embedded SQLite
# database (no external MySQL needed), detached, with all GPUs visible.
#   -p 5000:5000            expose the web server
#   LOCAL_DB_TYPE/PATH      select SQLite and its db file (relative to /app)
#   LLM_MODEL               local model to serve
#   -v /data:/data          host data dir (db files persist here)
#   -v /data/models:...     host directory with downloaded model weights
docker run --gpus all -d \
-p 5000:5000 \
-e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PATH=data/default_sqlite.db \
-e LLM_MODEL=vicuna-13b-v1.5 \
-e LANGUAGE=zh \
-v /data:/data \
-v /data/models:/app/models \
--name dbgpt \
eosphorosai/dbgpt

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Run the DB-GPT container with the embedded SQLite database and a proxy
# LLM (OpenAI-compatible API) instead of a local model.
# Your api key
PROXY_API_KEY="$PROXY_API_KEY"
# Use ${VAR:-default} (not ${VAR-'default'}): the ':-' form also applies the
# default when the variable is set but empty, and the previous inner single
# quotes would have become a literal part of the URL.
PROXY_SERVER_URL="${PROXY_SERVER_URL:-https://api.openai.com/v1/chat/completions}"
docker run --gpus all -d \
-p 5000:5000 \
-e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PATH=data/default_sqlite.db \
-e LLM_MODEL=proxyllm \
-e PROXY_API_KEY=$PROXY_API_KEY \
-e PROXY_SERVER_URL=$PROXY_SERVER_URL \
-e LANGUAGE=zh \
-v /data:/data \
-v /data/models:/app/models \
--name dbgpt \
eosphorosai/dbgpt

View File

@ -0,0 +1,59 @@
-- Example dataset: "school management system" text-to-SQL demo (SQLite
-- dialect). Creates students, courses and scores, then seeds sample rows.

-- Student master data.
CREATE TABLE students (
    student_id INTEGER PRIMARY KEY,
    student_name VARCHAR(100),
    major VARCHAR(100),
    year_of_enrollment INTEGER,
    student_age INTEGER
);

-- Course catalogue.
CREATE TABLE courses (
    course_id INTEGER PRIMARY KEY,
    course_name VARCHAR(100),
    credit REAL
);

-- Per-student, per-course results; the composite primary key forbids
-- duplicate rows for the same (student, course) pair.
CREATE TABLE scores (
    student_id INTEGER,
    course_id INTEGER,
    score INTEGER,
    semester VARCHAR(50),
    PRIMARY KEY (student_id, course_id),
    FOREIGN KEY (student_id) REFERENCES students(student_id),
    FOREIGN KEY (course_id) REFERENCES courses(course_id)
);

-- Sample rows (Chinese demo data).
INSERT INTO students (student_id, student_name, major, year_of_enrollment, student_age) VALUES
(1, '张三', '计算机科学', 2020, 20),
(2, '李四', '计算机科学', 2021, 19),
(3, '王五', '物理学', 2020, 21),
(4, '赵六', '数学', 2021, 19),
(5, '周七', '计算机科学', 2022, 18),
(6, '吴八', '物理学', 2020, 21),
(7, '郑九', '数学', 2021, 19),
(8, '孙十', '计算机科学', 2022, 18),
(9, '刘十一', '物理学', 2020, 21),
(10, '陈十二', '数学', 2021, 19);

INSERT INTO courses (course_id, course_name, credit) VALUES
(1, '计算机基础', 3),
(2, '数据结构', 4),
(3, '高等物理', 3),
(4, '线性代数', 4),
(5, '微积分', 5),
(6, '编程语言', 4),
(7, '量子力学', 3),
(8, '概率论', 4),
(9, '数据库系统', 4),
(10, '计算机网络', 4);

INSERT INTO scores (student_id, course_id, score, semester) VALUES
(1, 1, 90, '2020年秋季'),
(1, 2, 85, '2021年春季'),
(2, 1, 88, '2021年秋季'),
(2, 2, 90, '2022年春季'),
(3, 3, 92, '2020年秋季'),
(3, 4, 85, '2021年春季'),
(4, 3, 88, '2021年秋季'),
(4, 4, 86, '2022年春季'),
(5, 1, 90, '2022年秋季'),
(5, 2, 87, '2023年春季');

View File

@ -0,0 +1,59 @@
-- Example dataset: "e-commerce" text-to-SQL demo (SQLite dialect).
-- Creates users, products and orders, then seeds sample rows.

-- Registered shop users.
CREATE TABLE users (
    user_id INTEGER PRIMARY KEY,
    user_name VARCHAR(100),
    user_email VARCHAR(100),
    registration_date DATE,
    user_country VARCHAR(100)
);

-- Product catalogue with unit prices.
CREATE TABLE products (
    product_id INTEGER PRIMARY KEY,
    product_name VARCHAR(100),
    product_price REAL
);

-- Orders link a user to a product with a quantity and date.
CREATE TABLE orders (
    order_id INTEGER PRIMARY KEY,
    user_id INTEGER,
    product_id INTEGER,
    quantity INTEGER,
    order_date DATE,
    FOREIGN KEY (user_id) REFERENCES users(user_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);

-- Sample rows.
INSERT INTO users (user_id, user_name, user_email, registration_date, user_country) VALUES
(1, 'John', 'john@gmail.com', '2020-01-01', 'USA'),
(2, 'Mary', 'mary@gmail.com', '2021-01-01', 'UK'),
(3, 'Bob', 'bob@gmail.com', '2020-01-01', 'USA'),
(4, 'Alice', 'alice@gmail.com', '2021-01-01', 'UK'),
(5, 'Charlie', 'charlie@gmail.com', '2020-01-01', 'USA'),
(6, 'David', 'david@gmail.com', '2021-01-01', 'UK'),
(7, 'Eve', 'eve@gmail.com', '2020-01-01', 'USA'),
(8, 'Frank', 'frank@gmail.com', '2021-01-01', 'UK'),
(9, 'Grace', 'grace@gmail.com', '2020-01-01', 'USA'),
(10, 'Helen', 'helen@gmail.com', '2021-01-01', 'UK');

INSERT INTO products (product_id, product_name, product_price) VALUES
(1, 'iPhone', 699),
(2, 'Samsung Galaxy', 599),
(3, 'iPad', 329),
(4, 'Macbook', 1299),
(5, 'Apple Watch', 399),
(6, 'AirPods', 159),
(7, 'Echo', 99),
(8, 'Kindle', 89),
(9, 'Fire TV Stick', 39),
(10, 'Echo Dot', 49);

INSERT INTO orders (order_id, user_id, product_id, quantity, order_date) VALUES
(1, 1, 1, 1, '2022-01-01'),
(2, 1, 2, 1, '2022-02-01'),
(3, 2, 3, 2, '2022-03-01'),
(4, 2, 4, 1, '2022-04-01'),
(5, 3, 5, 2, '2022-05-01'),
(6, 3, 6, 3, '2022-06-01'),
(7, 4, 7, 2, '2022-07-01'),
(8, 4, 8, 1, '2022-08-01'),
(9, 5, 9, 2, '2022-09-01'),
(10, 5, 10, 3, '2022-10-01');

View File

@ -0,0 +1,17 @@
-- Test cases used to evaluate the SQL assistant against the SQLite demo
-- databases. Each row stores a natural-language question, the SQL the
-- assistant is expected to produce, and a description of the correct output.
CREATE TABLE test_cases (
    case_id INTEGER PRIMARY KEY AUTOINCREMENT,
    scenario_name VARCHAR(100),
    scenario_description TEXT,
    test_question VARCHAR(500),
    expected_sql TEXT,
    correct_output TEXT
);

-- NOTE: the expected_sql values must be valid SQLite. The last case
-- previously used MySQL's YEAR(); it is rewritten with strftime('%Y', ...)
-- so the stored expected SQL actually runs against the SQLite demo db.
INSERT INTO test_cases (scenario_name, scenario_description, test_question, expected_sql, correct_output) VALUES
('学校管理系统', '测试SQL助手的联合查询条件查询和排序功能', '查询所有学生的姓名,专业和成绩,按成绩降序排序', 'SELECT students.student_name, students.major, scores.score FROM students JOIN scores ON students.student_id = scores.student_id ORDER BY scores.score DESC;', '返回所有学生的姓名,专业和成绩,按成绩降序排序的结果'),
('学校管理系统', '测试SQL助手的联合查询条件查询和排序功能', '查询计算机科学专业的学生的平均成绩', 'SELECT AVG(scores.score) as avg_score FROM students JOIN scores ON students.student_id = scores.student_id WHERE students.major = ''计算机科学'';', '返回计算机科学专业学生的平均成绩'),
('学校管理系统', '测试SQL助手的联合查询条件查询和排序功能', '查询哪些学生在2023年秋季学期的课程学分总和超过15', 'SELECT students.student_name FROM students JOIN scores ON students.student_id = scores.student_id JOIN courses ON scores.course_id = courses.course_id WHERE scores.semester = ''2023年秋季'' GROUP BY students.student_id HAVING SUM(courses.credit) > 15;', '返回在2023年秋季学期的课程学分总和超过15的学生的姓名'),
('电商系统', '测试SQL助手的数据聚合和分组功能', '查询每个用户的总订单数量', 'SELECT users.user_name, COUNT(orders.order_id) as order_count FROM users JOIN orders ON users.user_id = orders.user_id GROUP BY users.user_id;', '返回每个用户的总订单数量'),
('电商系统', '测试SQL助手的数据聚合和分组功能', '查询每种商品的总销售额', 'SELECT products.product_name, SUM(products.product_price * orders.quantity) as total_sales FROM products JOIN orders ON products.product_id = orders.product_id GROUP BY products.product_id;', '返回每种商品的总销售额'),
('电商系统', '测试SQL助手的数据聚合和分组功能', '查询2023年最受欢迎的商品订单数量最多的商品', 'SELECT products.product_name FROM products JOIN orders ON products.product_id = orders.product_id WHERE strftime(''%Y'', orders.order_date) = ''2023'' GROUP BY products.product_id ORDER BY COUNT(orders.order_id) DESC LIMIT 1;', '返回2023年最受欢迎的商品订单数量最多的商品的名称');

View File

@ -17,16 +17,7 @@ As our project has the ability to achieve ChatGPT performance of over 85%, there
### 2. Install ### 2. Install
1.This project relies on a local MySQL database service, which you need to install locally. We recommend using Docker for installation. We use [Chroma embedding database](https://github.com/chroma-core/chroma) as the default for our vector database and use SQLite as the default for our database, so there is no need for special installation. If you choose to connect to other databases, you can follow our tutorial for installation and configuration.
```bash
$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
```
2. prepare server sql script
```bash
$ mysql -h127.0.0.1 -uroot -paa12345678 < ./assets/schema/knowledge_management.sql
```
We use [Chroma embedding database](https://github.com/chroma-core/chroma) as the default for our vector database, so there is no need for special installation. If you choose to connect to other databases, you can follow our tutorial for installation and configuration.
For the entire installation process of DB-GPT, we use the miniconda3 virtual environment. Create a virtual environment and install the Python dependencies. For the entire installation process of DB-GPT, we use the miniconda3 virtual environment. Create a virtual environment and install the Python dependencies.
```bash ```bash
@ -38,7 +29,6 @@ pip install -r requirements.txt
Before use DB-GPT Knowledge Management Before use DB-GPT Knowledge Management
```bash ```bash
python -m spacy download zh_core_web_sm python -m spacy download zh_core_web_sm
``` ```
Once the environment is installed, we have to create a new folder "models" in the DB-GPT project, and then we can put all the models downloaded from huggingface in this directory Once the environment is installed, we have to create a new folder "models" in the DB-GPT project, and then we can put all the models downloaded from huggingface in this directory
@ -100,16 +90,18 @@ $ bash docker/build_all_images.sh
Review images by listing them: Review images by listing them:
```bash ```bash
$ docker images|grep db-gpt $ docker images|grep "eosphorosai/dbgpt"
``` ```
Output should look something like the following: Output should look something like the following:
``` ```
db-gpt-allinone latest e1ffd20b85ac 45 minutes ago 14.5GB eosphorosai/dbgpt-allinone latest 349d49726588 27 seconds ago 15.1GB
db-gpt latest e36fb0cca5d9 3 hours ago 14GB eosphorosai/dbgpt latest eb3cdc5b4ead About a minute ago 14.5GB
``` ```
`eosphorosai/dbgpt` is the base image, which contains the project's base dependencies and an SQLite database. `eosphorosai/dbgpt-allinone` is built from `eosphorosai/dbgpt` and additionally contains a MySQL database. `eosphorosai/dbgpt` is the base image, which contains the project's base dependencies and an SQLite database. `eosphorosai/dbgpt-allinone` is built from `eosphorosai/dbgpt` and additionally contains a MySQL database.
You can pass some parameters to docker/build_all_images.sh. You can pass some parameters to docker/build_all_images.sh.
```bash ```bash
$ bash docker/build_all_images.sh \ $ bash docker/build_all_images.sh \
@ -122,19 +114,18 @@ You can execute the command `bash docker/build_all_images.sh --help` to see more
#### 4.2. Run all in one docker container #### 4.2. Run all in one docker container
**Run with local model** **Run with local model and SQLite database**
```bash ```bash
$ docker run --gpus "device=0" -d -p 3306:3306 \ $ docker run --gpus all -d \
-p 5000:5000 \ -p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \ -e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PASSWORD=aa123456 \ -e LOCAL_DB_PATH=data/default_sqlite.db \
-e MYSQL_ROOT_PASSWORD=aa123456 \
-e LLM_MODEL=vicuna-13b \ -e LLM_MODEL=vicuna-13b \
-e LANGUAGE=zh \ -e LANGUAGE=zh \
-v /data/models:/app/models \ -v /data/models:/app/models \
--name db-gpt-allinone \ --name dbgpt \
db-gpt-allinone eosphorosai/dbgpt
``` ```
Open http://localhost:5000 with your browser to see the product. Open http://localhost:5000 with your browser to see the product.
@ -146,7 +137,22 @@ Open http://localhost:5000 with your browser to see the product.
You can see log with command: You can see log with command:
```bash ```bash
$ docker logs db-gpt-allinone -f $ docker logs dbgpt -f
```
**Run with local model and MySQL database**
```bash
$ docker run --gpus all -d -p 3306:3306 \
-p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \
-e LOCAL_DB_PASSWORD=aa123456 \
-e MYSQL_ROOT_PASSWORD=aa123456 \
-e LLM_MODEL=vicuna-13b \
-e LANGUAGE=zh \
-v /data/models:/app/models \
--name dbgpt \
eosphorosai/dbgpt-allinone
``` ```
**Run with openai interface** **Run with openai interface**
@ -154,7 +160,7 @@ $ docker logs db-gpt-allinone -f
```bash ```bash
$ PROXY_API_KEY="You api key" $ PROXY_API_KEY="You api key"
$ PROXY_SERVER_URL="https://api.openai.com/v1/chat/completions" $ PROXY_SERVER_URL="https://api.openai.com/v1/chat/completions"
$ docker run --gpus "device=0" -d -p 3306:3306 \ $ docker run --gpus all -d -p 3306:3306 \
-p 5000:5000 \ -p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \ -e LOCAL_DB_HOST=127.0.0.1 \
-e LOCAL_DB_PASSWORD=aa123456 \ -e LOCAL_DB_PASSWORD=aa123456 \
@ -164,8 +170,8 @@ $ docker run --gpus "device=0" -d -p 3306:3306 \
-e PROXY_SERVER_URL=$PROXY_SERVER_URL \ -e PROXY_SERVER_URL=$PROXY_SERVER_URL \
-e LANGUAGE=zh \ -e LANGUAGE=zh \
-v /data/models/text2vec-large-chinese:/app/models/text2vec-large-chinese \ -v /data/models/text2vec-large-chinese:/app/models/text2vec-large-chinese \
--name db-gpt-allinone \ --name dbgpt \
db-gpt-allinone eosphorosai/dbgpt-allinone
``` ```
- `-e LLM_MODEL=proxyllm`, means we use proxy llm(openai interface, fastchat interface...) - `-e LLM_MODEL=proxyllm`, means we use proxy llm(openai interface, fastchat interface...)

View File

@ -1,5 +1,6 @@
from enum import auto, Enum from enum import auto, Enum
from typing import List, Any from typing import List, Any
import os
class SeparatorStyle(Enum): class SeparatorStyle(Enum):
@ -24,6 +25,7 @@ class DBType(Enum):
Mysql = DbInfo("mysql") Mysql = DbInfo("mysql")
OCeanBase = DbInfo("oceanbase") OCeanBase = DbInfo("oceanbase")
DuckDb = DbInfo("duckdb", True) DuckDb = DbInfo("duckdb", True)
SQLite = DbInfo("sqlite", True)
Oracle = DbInfo("oracle") Oracle = DbInfo("oracle")
MSSQL = DbInfo("mssql") MSSQL = DbInfo("mssql")
Postgresql = DbInfo("postgresql") Postgresql = DbInfo("postgresql")
@ -40,3 +42,12 @@ class DBType(Enum):
if item.value() == db_type: if item.value() == db_type:
return item return item
return None return None
@staticmethod
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
"""Parse out the database name of the embedded database from the file path"""
base_name = os.path.basename(local_db_path)
db_name = os.path.splitext(base_name)[0]
if "." in db_name:
db_name = os.path.splitext(db_name)[0]
return db_type + "_" + db_name

View File

@ -121,8 +121,12 @@ class Config(metaclass=Singleton):
) )
### default Local database connection configuration ### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "") self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
self.LOCAL_DB_HOST = "127.0.0.1"
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME") self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306)) self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")

View File

View File

@ -8,6 +8,7 @@ from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.singleton import Singleton from pilot.singleton import Singleton
@ -89,12 +90,28 @@ class ConnectManager:
"", "",
) )
if CFG.LOCAL_DB_PATH: if CFG.LOCAL_DB_PATH:
# default file db is duckdb db_name = CFG.LOCAL_DB_NAME
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH) db_type = CFG.LOCAL_DB_TYPE
db_path = CFG.LOCAL_DB_PATH
if not db_type:
# Default file database type
db_type = DBType.DuckDb.value()
if not db_name:
db_type, db_name = self._parse_file_db_info(db_type, db_path)
if db_name: if db_name:
self.storage.add_file_db( print(
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
) )
self.storage.add_file_db(db_name, db_type, db_path)
def _parse_file_db_info(self, db_type: str, db_path: str):
if db_type is None or db_type == DBType.DuckDb.value():
# file db is duckdb
db_name = self.storage.get_file_db_name(db_path)
db_type = DBType.DuckDb.value()
else:
db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
return db_type, db_name
def get_connect(self, db_name): def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name) db_config = self.storage.get_db_config(db_name)

View File

@ -21,6 +21,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session
from pilot.common.schema import DBType
from pilot.connections.base import BaseConnect from pilot.connections.base import BaseConnect
from pilot.configs.config import Config from pilot.configs.config import Config
@ -78,16 +79,7 @@ class RDBMSDatabase(BaseConnect):
self._metadata = MetaData() self._metadata = MetaData()
self._metadata.reflect(bind=self._engine) self._metadata.reflect(bind=self._engine)
# including view support by adding the views as well as tables to the all self._all_tables = self._sync_tables_from_db()
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=self._engine.url.database)
+ (
self._inspector.get_view_names(schema=self._engine.url.database)
if self.view_support
else []
)
)
@classmethod @classmethod
def from_uri_db( def from_uri_db(
@ -128,6 +120,26 @@ class RDBMSDatabase(BaseConnect):
"""Return string representation of dialect to use.""" """Return string representation of dialect to use."""
return self._engine.dialect.name return self._engine.dialect.name
def _sync_tables_from_db(self) -> Iterable[str]:
"""Read table information from database"""
# TODO Use a background thread to refresh periodically
# SQL will raise error with schema
_schema = (
None if self.db_type == DBType.SQLite.value() else self._engine.url.database
)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=_schema)
+ (
self._inspector.get_view_names(schema=_schema)
if self.view_support
else []
)
)
return self._all_tables
def get_usable_table_names(self) -> Iterable[str]: def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available.""" """Get names of tables available."""
if self._include_tables: if self._include_tables:
@ -250,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message""" """Format the error message"""
return f"Error: {e}" return f"Error: {e}"
def __write(self, session, write_sql): def _write(self, session, write_sql):
print(f"Write[{write_sql}]") print(f"Write[{write_sql}]")
db_cache = self._engine.url.database db_cache = self._engine.url.database
result = session.execute(text(write_sql)) result = session.execute(text(write_sql))
@ -279,7 +291,7 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all": if fetch == "all":
result = cursor.fetchall() result = cursor.fetchall()
elif fetch == "one": elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore result = result = [cursor.fetchone()] # type: ignore
else: else:
raise ValueError("Fetch parameter must be either 'one' or 'all'") raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = tuple(i[0:] for i in cursor.keys()) field_names = tuple(i[0:] for i in cursor.keys())
@ -305,11 +317,10 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all": if fetch == "all":
result = cursor.fetchall() result = cursor.fetchall()
elif fetch == "one": elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore result = [cursor.fetchone()] # type: ignore
else: else:
raise ValueError("Fetch parameter must be either 'one' or 'all'") raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys()) field_names = list(i[0:] for i in cursor.keys())
result = list(result) result = list(result)
return field_names, result return field_names, result
@ -323,7 +334,7 @@ class RDBMSDatabase(BaseConnect):
if sql_type == "SELECT": if sql_type == "SELECT":
return self.__query(session, command, fetch) return self.__query(session, command, fetch)
else: else:
self.__write(session, command) self._write(session, command)
select_sql = self.convert_sql_write_to_select(command) select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}") print(f"write result query:{select_sql}")
return self.__query(session, select_sql) return self.__query(session, select_sql)
@ -332,6 +343,11 @@ class RDBMSDatabase(BaseConnect):
print(f"DDL execution determines whether to enable through configuration ") print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command)) cursor = session.execute(text(command))
session.commit() session.commit()
_show_columns_sql = (
f"PRAGMA table_info({table_name})"
if self.db_type == "sqlite"
else f"SHOW COLUMNS FROM {table_name}"
)
if cursor.returns_rows: if cursor.returns_rows:
result = cursor.fetchall() result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys()) field_names = tuple(i[0:] for i in cursor.keys())
@ -339,10 +355,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names) result.insert(0, field_names)
print("DDL Result:" + str(result)) print("DDL Result:" + str(result))
if not result: if not result:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}") return self.__query(session, _show_columns_sql)
return result return result
else: else:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}") return self.__query(session, _show_columns_sql)
def run_no_throw(self, session, command: str, fetch: str = "all") -> List: def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.

View File

@ -0,0 +1,83 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config
from pilot.common.schema import DBType
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.logs import logger
CFG = Config()
class BaseDao:
    """Base class for data access objects.

    The SQLAlchemy engine and session factory are created lazily on first
    access. If the configured local database is a file database and
    ``create_not_exist_table=True``, missing ORM tables are created
    automatically when the engine is first built.
    """

    def __init__(
        self, orm_base=None, database: str = None, create_not_exist_table: bool = False
    ) -> None:
        self._orm_base = orm_base
        self._database = database
        self._create_not_exist_table = create_not_exist_table
        # Lazily-initialized resources (see properties below).
        self._db_engine = None
        self._session = None
        self._connection = None

    @property
    def db_engine(self):
        """Return the engine, resolving it from configuration on first use."""
        if self._db_engine is None:
            engine, conn = _get_db_engine(
                self._orm_base, self._database, self._create_not_exist_table
            )
            self._db_engine = engine
            self._connection = conn
        return self._db_engine

    @property
    def Session(self):
        """Return a cached ``sessionmaker`` bound to :pyattr:`db_engine`."""
        if self._session is None:
            self._session = sessionmaker(bind=self.db_engine)
        return self._session
def _get_db_engine(
    orm_base=None, database: str = None, create_not_exist_table: bool = False
):
    """Resolve the SQLAlchemy engine (and underlying connection) for DAOs.

    For MySQL (the default) an engine is created directly from the
    ``LOCAL_DB_*`` settings. For any other database type the engine is
    obtained from the global connection manager; for file databases the
    database name is derived from ``LOCAL_DB_PATH``.

    Returns:
        Tuple of ``(db_engine, connection)``; ``connection`` is ``None``
        for the direct-MySQL path.
    """
    db_engine = None
    connection: RDBMSDatabase = None
    db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
    if db_type is None or db_type == DBType.Mysql:
        # Default / unrecognized type: connect straight to MySQL.
        db_engine = create_engine(
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
    else:
        # Fix typo: was `db_namager`.
        db_manager = CFG.LOCAL_DB_MANAGE
        if not db_manager:
            raise Exception(
                "LOCAL_DB_MANAGE is not initialized, please check the system configuration"
            )
        if db_type.is_file_db():
            db_path = CFG.LOCAL_DB_PATH
            if db_path is None or db_path == "":
                raise ValueError(
                    "You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file"
                )
            # Derive the registered database name from the file path.
            _, database = db_manager._parse_file_db_info(db_type.value(), db_path)
            logger.info(
                f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}"
            )
        logger.info(f"Get DAO database connection with database name {database}")
        connection: RDBMSDatabase = db_manager.get_connect(database)
        if not isinstance(connection, RDBMSDatabase):
            raise ValueError(
                "Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration"
            )
        db_engine = connection._engine
    # Guard against db_type being None (unrecognized LOCAL_DB_TYPE) before
    # calling is_file_db(), which would otherwise raise AttributeError.
    if (
        db_type is not None
        and db_type.is_file_db()
        and orm_base is not None
        and create_not_exist_table
    ):
        logger.info("Current database is file database, create not exist table")
        orm_base.metadata.create_all(db_engine)
    return db_engine, connection

View File

@ -0,0 +1,125 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Any, Iterable
from sqlalchemy import create_engine, text
from pilot.connections.rdbms.base import RDBMSDatabase
class SQLiteConnect(RDBMSDatabase):
"""Connect SQLite Database fetch MetaData
Args:
Usage:
"""
db_type: str = "sqlite"
db_dialect: str = "sqlite"
@classmethod
def from_file_path(
    cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
    """Create a connection by building a SQLAlchemy engine for *file_path*."""
    args = engine_args or {}
    # Allow the engine's connections to be used from multiple threads
    # (the web server accesses the db outside the creating thread).
    args["connect_args"] = {"check_same_thread": False}
    return cls(create_engine(f"sqlite:///{file_path}", **args), **kwargs)
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
cursor = self.session.execute(text(f"PRAGMA index_list({table_name})"))
indexes = cursor.fetchall()
return [(index[1], index[3]) for index in indexes]
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
cursor = self.session.execute(
text(
f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'"
)
)
ans = cursor.fetchall()
return ans[0][0]
def get_fields(self, table_name):
"""Get column fields about specified table."""
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
fields = cursor.fetchall()
print(fields)
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
def get_users(self):
return []
def get_grants(self):
return []
def get_collation(self):
"""Get collation."""
return "UTF-8"
def get_charset(self):
return "UTF-8"
def get_database_list(self):
return []
def get_database_names(self):
return []
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
)
view_results = self.session.execute(
"SELECT name FROM sqlite_master WHERE type='view'"
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
self._all_tables = table_results.union(view_results)
self._metadata.reflect(bind=self._engine)
return self._all_tables
def _write(self, session, write_sql):
print(f"Write[{write_sql}]")
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def get_table_comments(self, db_name=None):
cursor = self.session.execute(
text(
f"""
SELECT name, sql FROM sqlite_master WHERE type='table'
"""
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
SELECT name FROM sqlite_master WHERE type='table'
"""
cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall()
results = []
for row in tables_results:
table_name = row[0]
_sql = f"""
PRAGMA table_info({table_name})
"""
cursor_colums = self.session.execute(text(_sql))
colum_results = cursor_colums.fetchall()
table_colums = []
for row_col in colum_results:
field_info = list(row_col)
table_colums.append(field_info[1])
results.append(f"{table_name}({','.join(table_colums)});")
return results

View File

@ -0,0 +1,123 @@
"""
Run unit test with command: pytest pilot/connections/rdbms/tests/test_conn_sqlite.py
"""
import pytest
import tempfile
import os
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
@pytest.fixture
def db():
    """Yield a SQLiteConnect backed by a throwaway on-disk database file."""
    # Create an empty file that survives closing the handle (delete=False),
    # so SQLite can reopen it by path.
    handle = tempfile.NamedTemporaryFile(delete=False)
    handle.close()
    connection = SQLiteConnect.from_file_path(handle.name)
    yield connection
    # Remove the database file once the consuming test has finished.
    os.unlink(handle.name)
def test_get_table_names(db):
    """A freshly created database exposes no tables."""
    names = list(db.get_table_names())
    assert names == []
def test_get_table_info(db):
    """Table info of an empty database is the empty string."""
    info = db.get_table_info()
    assert info == ""
def test_get_table_info_with_table(db):
    """After creating a table, its DDL appears in the table info."""
    db.run(db.session, "CREATE TABLE test (id INTEGER);")
    # Refresh the cached table list; the original wrapped this call in a
    # debug print, which is dropped here.
    db._sync_tables_from_db()
    table_info = db.get_table_info()
    assert "CREATE TABLE test" in table_info
def test_run_sql(db):
    """Running DDL returns rows whose header describes PRAGMA table_info."""
    result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
    expected_header = ("cid", "name", "type", "notnull", "dflt_value", "pk")
    assert result[0] == expected_header
def test_run_no_throw(db):
    """Invalid SQL is reported as an error string instead of raising."""
    message = db.run_no_throw(db.session, "this is a error sql")
    assert message.startswith("Error:")
def test_get_indexes(db):
    """An explicitly created index is listed with origin 'c' (CREATE INDEX)."""
    db.run(db.session, "CREATE TABLE test (name TEXT);")
    db.run(db.session, "CREATE INDEX idx_name ON test(name);")
    indexes = db.get_indexes("test")
    assert indexes == [("idx_name", "c")]
def test_get_indexes_empty(db):
    """A table without explicit indexes yields an empty index list."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    indexes = db.get_indexes("test")
    assert indexes == []
def test_get_show_create_table(db):
    """The stored DDL is returned verbatim (no trailing semicolon)."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    ddl = db.get_show_create_table("test")
    assert ddl == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
def test_get_fields(db):
    """Column metadata comes back as (name, type, notnull, dflt_value, pk)."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    fields = db.get_fields("test")
    assert fields == [("id", "INTEGER", 0, None, 1)]
def test_get_charset(db):
    """SQLite connections always report a UTF-8 charset."""
    charset = db.get_charset()
    assert charset == "UTF-8"
def test_get_collation(db):
    """SQLite connections always report a UTF-8 collation."""
    collation = db.get_collation()
    assert collation == "UTF-8"
def test_table_simple_info(db):
    """Simple info renders each table as name(col1,col2,...);."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    info = db.table_simple_info()
    assert info == ["test(id);"]
def test_get_table_info_no_throw(db):
    """Asking for an unknown table yields an error string, not an exception."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    message = db.get_table_info_no_throw("xxxx_table")
    assert message.startswith("Error:")
def test_query_ex(db):
    """query_ex returns column names plus all rows, or one row with fetch='one'."""
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    db.run(db.session, "insert into test(id) values (1)")
    db.run(db.session, "insert into test(id) values (2)")
    # Default fetch: every row.
    field_names, rows = db.query_ex(db.session, "select * from test")
    assert field_names == ["id"]
    assert rows == [(1,), (2,)]
    # fetch="one": a single row, still wrapped in a list.
    field_names, rows = db.query_ex(db.session, "select * from test", fetch="one")
    assert field_names == ["id"]
    assert rows == [(1,)]
def test_convert_sql_write_to_select(db):
    """Placeholder for convert_sql_write_to_select coverage."""
    # TODO: exercise convert_sql_write_to_select once its behavior is pinned down.
    pass
def test_get_grants(db):
    """SQLite has no grant system, so the grant list is empty."""
    grants = db.get_grants()
    assert grants == []
def test_get_users(db):
    """SQLite has no user catalog, so the user list is empty."""
    users = db.get_users()
    assert users == []
def test_get_table_comments(db):
    """Table comments are (name, create_sql) pairs; an empty DB gives none."""
    assert db.get_table_comments() == []
    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
    comments = db.get_table_comments()
    assert comments == [("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")]
def test_get_database_list(db):
    """SQLite has no server-side database catalog, so the list is empty."""
    # The original comparison discarded its result (no assert), so the test
    # could never fail; assert explicitly.
    assert db.get_database_list() == []
def test_get_database_names(db):
    """SQLite has no server-side database catalog, so the name list is empty."""
    # The original comparison discarded its result (no assert), so the test
    # could never fail; assert explicitly.
    assert db.get_database_names() == []

View File

@ -144,10 +144,8 @@ async def document_upload(
request = KnowledgeDocumentRequest() request = KnowledgeDocumentRequest()
request.doc_name = doc_name request.doc_name = doc_name
request.doc_type = doc_type request.doc_type = doc_type
request.content = ( request.content = os.path.join(
os.path.join( KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
) )
return Result.succ( return Result.succ(
knowledge_space_service.create_knowledge_document( knowledge_space_service.create_knowledge_document(

View File

@ -5,7 +5,7 @@ from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, f
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
@ -27,14 +27,11 @@ class DocumentChunkEntity(Base):
return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class DocumentChunkDao: class DocumentChunkDao(BaseDao):
def __init__(self): def __init__(self):
database = "knowledge_management" super().__init__(
self.db_engine = create_engine( database="knowledge_management", orm_base=Base, create_not_exist_table=True
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True,
) )
self.Session = sessionmaker(bind=self.db_engine)
def create_documents_chunks(self, documents: List): def create_documents_chunks(self, documents: List):
session = self.Session() session = self.Session()

View File

@ -4,7 +4,7 @@ from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, f
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
@ -19,7 +19,7 @@ class KnowledgeDocumentEntity(Base):
space = Column(String(100)) space = Column(String(100))
chunk_size = Column(Integer) chunk_size = Column(Integer)
status = Column(String(100)) status = Column(String(100))
last_sync = Column(String(100)) last_sync = Column(DateTime)
content = Column(Text) content = Column(Text)
result = Column(Text) result = Column(Text)
vector_ids = Column(Text) vector_ids = Column(Text)
@ -30,14 +30,11 @@ class KnowledgeDocumentEntity(Base):
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao: class KnowledgeDocumentDao(BaseDao):
def __init__(self): def __init__(self):
database = "knowledge_management" super().__init__(
self.db_engine = create_engine( database="knowledge_management", orm_base=Base, create_not_exist_table=True
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True,
) )
self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_document(self, document: KnowledgeDocumentEntity): def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.Session()

View File

@ -2,11 +2,11 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from sqlalchemy.orm import sessionmaker from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base() Base = declarative_base()
@ -27,14 +27,11 @@ class KnowledgeSpaceEntity(Base):
return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeSpaceDao: class KnowledgeSpaceDao(BaseDao):
def __init__(self): def __init__(self):
database = "knowledge_management" super().__init__(
self.db_engine = create_engine( database="knowledge_management", orm_base=Base, create_not_exist_table=True
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True,
) )
self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_space(self, space: KnowledgeSpaceRequest): def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session() session = self.Session()