mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
docs: add dbgpt_hub usage documents (#955)
This commit is contained in:
parent
ba8fa8774d
commit
aec124a5f1
Binary file not shown.
Before Width: | Height: | Size: 214 KiB After Width: | Height: | Size: 216 KiB |
44
dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
Normal file
44
dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a DuckDB connection backed by a throwaway database file.

    The previous version leaked two files per test run: the
    NamedTemporaryFile itself (created with delete=False) and the derived
    "<name>duckdb.db" database file, neither of which was ever removed.
    Clean both up once the test has finished.
    """
    import os

    temp_db_file = tempfile.NamedTemporaryFile(delete=False)
    temp_db_file.close()
    db_path = temp_db_file.name + "duckdb.db"
    conn = DuckDbConnect.from_file_path(db_path)
    try:
        yield conn
    finally:
        # Best-effort cleanup; ignore files that were never created.
        for path in (temp_db_file.name, db_path):
            try:
                os.remove(path)
            except OSError:
                pass
|
||||
|
||||
|
||||
def test_get_users(db):
    """A fresh DuckDB database has no users configured."""
    users = db.get_users()
    assert users == []
|
||||
|
||||
|
||||
def test_get_table_names(db):
    """No tables exist in a newly created database."""
    names = list(db.get_table_names())
    assert names == []
|
||||
|
||||
|
||||
def test_get_users_empty(db):
    """A fresh DuckDB database has no users configured.

    Renamed: this was a second `test_get_users` definition, which silently
    shadowed the earlier test of the same name (flake8 F811) so only one of
    the two ever ran.
    """
    assert db.get_users() == []
|
||||
|
||||
|
||||
def test_get_charset(db):
    """DuckDB reports UTF-8 as its charset."""
    charset = db.get_charset()
    assert charset == "UTF-8"
|
||||
|
||||
|
||||
def test_get_table_comments(db):
    """Asking for comments on an absent table yields an empty list."""
    comments = db.get_table_comments("test")
    assert comments == []
|
||||
|
||||
|
||||
def test_table_simple_info(db):
    """Simple table info is empty for a fresh database."""
    info = db.table_simple_info()
    assert info == []
|
||||
|
||||
|
||||
def test_execute(db):
    """Running a literal SELECT returns the value back as a string column."""
    first_row = db.run("SELECT 42")[0]
    assert list(first_row) == ["42"]
|
@ -1,4 +1,4 @@
|
||||
# DB-GPT Website
|
||||
# DB-GPT documentation
|
||||
|
||||
## Quick Start
|
||||
|
||||
@ -17,9 +17,3 @@ yarn start
|
||||
|
||||
The default service starts on port `3000`, visit `localhost:3000`
|
||||
|
||||
## Docker development
|
||||
|
||||
```commandline
|
||||
docker build -t dbgptweb .
|
||||
docker run --restart=unless-stopped -d -p 3000:3000 dbgptweb
|
||||
```
|
||||
|
118
docs/docs/application/fine_tuning_manual/dbgpt_hub.md
Normal file
118
docs/docs/application/fine_tuning_manual/dbgpt_hub.md
Normal file
@ -0,0 +1,118 @@
|
||||
# Fine-Tuning use dbgpt_hub
|
||||
|
||||
The DB-GPT-Hub project has released a pip package to lower the threshold for Text2SQL training. In addition to fine-tuning through the scripts provided in the repository, you can also use the Python package we provide
for fine-tuning.
|
||||
|
||||
## Install
|
||||
```
|
||||
pip install dbgpt_hub
|
||||
```
|
||||
|
||||
## Show Baseline
|
||||
```python
|
||||
from dbgpt_hub.baseline import show_scores
|
||||
show_scores()
|
||||
```
|
||||
<p align="left">
|
||||
<img src={'/img/ft/baseline.png'} width="720px" />
|
||||
</p>
|
||||
|
||||
## Fine-tuning
|
||||
|
||||
```python
|
||||
from dbgpt_hub.data_process import preprocess_sft_data
|
||||
from dbgpt_hub.train import train_sft
|
||||
from dbgpt_hub.predict import start_predict
|
||||
from dbgpt_hub.eval import start_evaluate
|
||||
```
|
||||
|
||||
|
||||
Preprocessing data into fine-tuned data format.
|
||||
```
|
||||
data_folder = "dbgpt_hub/data"
|
||||
data_info = [
|
||||
{
|
||||
"data_source": "spider",
|
||||
"train_file": ["train_spider.json", "train_others.json"],
|
||||
"dev_file": ["dev.json"],
|
||||
"tables_file": "tables.json",
|
||||
"db_id_name": "db_id",
|
||||
"is_multiple_turn": False,
|
||||
"train_output": "spider_train.json",
|
||||
"dev_output": "spider_dev.json",
|
||||
}
|
||||
]
|
||||
|
||||
preprocess_sft_data(
|
||||
data_folder = data_folder,
|
||||
data_info = data_info
|
||||
)
|
||||
```
|
||||
|
||||
Fine-tune the basic model and generate model weights
|
||||
```
|
||||
train_args = {
|
||||
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
|
||||
"do_train": True,
|
||||
"dataset": "example_text2sql_train",
|
||||
"max_source_length": 2048,
|
||||
"max_target_length": 512,
|
||||
"finetuning_type": "lora",
|
||||
"lora_target": "q_proj,v_proj",
|
||||
"template": "llama2",
|
||||
"lora_rank": 64,
|
||||
"lora_alpha": 32,
|
||||
"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
|
||||
"overwrite_cache": True,
|
||||
"overwrite_output_dir": True,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 16,
|
||||
"lr_scheduler_type": "cosine_with_restarts",
|
||||
"logging_steps": 50,
|
||||
"save_steps": 2000,
|
||||
"learning_rate": 2e-4,
|
||||
"num_train_epochs": 8,
|
||||
"plot_loss": True,
|
||||
"bf16": True,
|
||||
}
|
||||
|
||||
train_sft(train_args)
|
||||
|
||||
```
|
||||
|
||||
Predictive model output results
|
||||
```
|
||||
predict_args = {
|
||||
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
|
||||
"template": "llama2",
|
||||
"finetuning_type": "lora",
|
||||
"checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
|
||||
"predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
|
||||
"predict_out_dir": "dbgpt_hub/output/",
|
||||
"predicted_out_filename": "pred_sql.sql",
|
||||
}
|
||||
start_predict(predict_args)
|
||||
|
||||
```
|
||||
|
||||
Evaluate the accuracy of the output results on the test datasets
|
||||
|
||||
```
|
||||
evaluate_args = {
|
||||
"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
|
||||
"gold": "./dbgpt_hub/data/eval_data/gold.txt",
|
||||
"gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
|
||||
"db": "./dbgpt_hub/data/spider/database",
|
||||
"table": "./dbgpt_hub/data/eval_data/tables.json",
|
||||
"table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
|
||||
"etype": "exec",
|
||||
"plug_value": True,
|
||||
"keep_distict": False,
|
||||
"progress_bar_for_each_datapoint": False,
|
||||
"natsql": False,
|
||||
}
|
||||
start_evaluate(evaluate_args)
|
||||
```
|
||||
|
||||
|
||||
|
@ -1,55 +0,0 @@
|
||||
ChatData & ChatDB
|
||||
==================================
|
||||
ChatData generates SQL from natural language and executes it. ChatDB involves conversing with metadata from the
|
||||
Database, including metadata about databases, tables, and
|
||||
fields.
|
||||
|
||||
### 1.Choose Datasource
|
||||
|
||||
If you are using DB-GPT for the first time, you need to add a data source and set the relevant connection information
|
||||
for the data source.
|
||||
|
||||
```{tip}
|
||||
there are some example data in DB-GPT-NEW/DB-GPT/docker/examples
|
||||
|
||||
you can execute sql script to generate data.
|
||||
```
|
||||
|
||||
#### 1.1 Datasource management
|
||||
|
||||

|
||||
|
||||
#### 1.2 Connection management
|
||||
|
||||

|
||||
|
||||
#### 1.3 Add Datasource
|
||||
|
||||

|
||||
|
||||
```{note}
|
||||
DB-GPT currently supports the following datasource types
|
||||
|
||||
* Mysql
|
||||
* Sqlite
|
||||
* DuckDB
|
||||
* Clickhouse
|
||||
* Mssql
|
||||
```
|
||||
|
||||
### 2.ChatData
|
||||
##### Preview Mode
|
||||
After successfully setting up the data source, you can start conversing with the database. You can ask it to generate
|
||||
SQL for you or inquire about relevant information on the database's metadata.
|
||||

|
||||
|
||||
##### Editor Mode
|
||||
In Editor Mode, you can edit your sql and execute it.
|
||||

|
||||
|
||||
|
||||
### 3.ChatDB
|
||||
|
||||

|
||||
|
||||
|
@ -82,4 +82,10 @@ Connect various data sources
|
||||
Observing & monitoring
|
||||
|
||||
- [Evaluation](/docs/modules/eval)
|
||||
Evaluate framework performance and accuracy
|
||||
Evaluate framework performance and accuracy
|
||||
|
||||
## Community
|
||||
If you encounter any problems during the process, you can submit an [issue](https://github.com/eosphoros-ai/DB-GPT/issues) and communicate with us.
|
||||
|
||||
We welcome [discussions](https://github.com/orgs/eosphoros-ai/discussions) in the community
|
||||
|
||||
|
@ -161,6 +161,10 @@ const sidebars = {
|
||||
type: 'doc',
|
||||
id: 'application/fine_tuning_manual/text_to_sql',
|
||||
},
|
||||
{
|
||||
type: 'doc',
|
||||
id: 'application/fine_tuning_manual/dbgpt_hub',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
@ -224,10 +228,6 @@ const sidebars = {
|
||||
type: 'doc',
|
||||
id: 'faq/kbqa',
|
||||
}
|
||||
,{
|
||||
type: 'doc',
|
||||
id: 'faq/chatdata',
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
|
BIN
docs/static/img/ft/baseline.png
vendored
Normal file
BIN
docs/static/img/ft/baseline.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 316 KiB |
0
tests/intetration_tests/datasource/__init__.py
Normal file
0
tests/intetration_tests/datasource/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a Clickhouse connection to a local server's `default` database."""
    connection = ClickhouseConnect.from_uri_db("localhost", 8123, "default", "", "default")
    yield connection
12
tests/intetration_tests/datasource/test_conn_doris.py
Normal file
12
tests/intetration_tests/datasource/test_conn_doris.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""
|
||||
Run unit test with command: pytest tests/intetration_tests/datasource/test_conn_doris.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from dbgpt.datasource.rdbms.conn_doris import DorisConnect
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a Doris connection to a local frontend on port 9030."""
    connection = DorisConnect.from_uri_db("localhost", 9030, "root", "", "test")
    yield connection
|
91
tests/intetration_tests/datasource/test_conn_mysql.py
Normal file
91
tests/intetration_tests/datasource/test_conn_mysql.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
Run unit test with command: pytest tests/intetration_tests/datasource/test_conn_mysql.py
|
||||
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
|
||||
mysql -h 127.0.0.1 -uroot -p -P3307
|
||||
Enter password:
|
||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||
Your MySQL connection id is 2
|
||||
Server version: 5.7.41 MySQL Community Server (GPL)
|
||||
|
||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||
|
||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||
affiliates. Other names may be trademarks of their respective
|
||||
owners.
|
||||
|
||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||
|
||||
> create database test;
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect
|
||||
|
||||
_create_table_sql = """
|
||||
CREATE TABLE IF NOT EXISTS `test` (
|
||||
`id` int(11) DEFAULT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a MySQL connection to the dockerized test server on port 3307."""
    connection = MySQLConnect.from_uri_db(
        "localhost",
        3307,
        "root",
        "******",
        "test",
        engine_args={"connect_args": {"charset": "utf8mb4"}},
    )
    yield connection
|
||||
|
||||
|
||||
def test_get_usable_table_names(db):
    """After creating `test`, the usable-table list is still empty."""
    db.run(_create_table_sql)
    synced = db._sync_tables_from_db()
    print(synced)
    names = list(db.get_usable_table_names())
    assert names == []
|
||||
|
||||
|
||||
def test_get_table_info(db):
    """Aggregated table info contains the DDL for the `test` table."""
    info = db.get_table_info()
    assert "CREATE TABLE test" in info
|
||||
|
||||
|
||||
def test_get_table_info_with_table(db):
    """After creating and syncing `test`, its DDL shows up in table info."""
    db.run(_create_table_sql)
    synced = db._sync_tables_from_db()
    print(synced)
    assert "CREATE TABLE test" in db.get_table_info()
|
||||
|
||||
|
||||
def test_run_no_throw(db):
    """Invalid SQL is reported as an Error string instead of raising."""
    result = db.run_no_throw("this is a error sql")
    assert result.startswith("Error:")
|
||||
|
||||
|
||||
def test_get_index_empty(db):
    """The plain `test` table carries no indexes."""
    db.run(_create_table_sql)
    indexes = db.get_indexes("test")
    assert indexes == []
|
||||
|
||||
|
||||
def test_get_fields(db):
    """The first field reported for `test` is the `id` column."""
    db.run(_create_table_sql)
    first_field = list(db.get_fields("test")[0])[0]
    assert first_field == "id"
|
||||
|
||||
|
||||
def test_get_charset(db):
    """Server charset is either utf8mb4 or the MySQL default latin1."""
    assert db.get_charset() in ("utf8mb4", "latin1")
|
||||
|
||||
|
||||
def test_get_collation(db):
    """Collation matches the default for whichever server charset is active."""
    collation = db.get_collation()
    assert collation in ("utf8mb4_general_ci", "latin1_swedish_ci")
|
||||
|
||||
|
||||
def test_get_users(db):
    """The root@% account exists on the test server."""
    users = db.get_users()
    assert ("root", "%") in users
|
||||
|
||||
|
||||
def test_get_database_lists(db):
    """Only the `test` database is visible to this connection."""
    databases = db.get_database_list()
    assert databases == ["test"]
|
53
tests/intetration_tests/datasource/test_conn_starrocks.py
Normal file
53
tests/intetration_tests/datasource/test_conn_starrocks.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
Run unit test with command: pytest tests/intetration_tests/datasource/test_conn_starrocks.py
|
||||
|
||||
docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
|
||||
|
||||
mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
|
||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||
Your MySQL connection id is 184
|
||||
Server version: 5.1.0 3.1.5-5d8438a
|
||||
|
||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||
|
||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||
affiliates. Other names may be trademarks of their respective
|
||||
owners.
|
||||
|
||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||
|
||||
> create database test;
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect
|
||||
|
||||
|
||||
@pytest.fixture
def db():
    """Yield a StarRocks connection to the local all-in-one container."""
    connection = StarRocksConnect.from_uri_db("localhost", 9030, "root", "", "test")
    yield connection
|
||||
|
||||
|
||||
def test_get_table_names(db):
    """A fresh `test` database contains no tables."""
    names = list(db.get_table_names())
    assert names == []
|
||||
|
||||
|
||||
def test_get_table_info(db):
    """With no tables present, aggregated table info is the empty string."""
    info = db.get_table_info()
    assert info == ""
|
||||
|
||||
|
||||
def test_get_table_info_with_table(db):
    """After creating and syncing `test`, its DDL appears in table info."""
    db.run("create table test(id int)")
    synced = db._sync_tables_from_db()
    print(synced)
    assert "CREATE TABLE test" in db.get_table_info()
|
||||
|
||||
|
||||
def test_run_no_throw(db):
    """Broken SQL comes back as an Error string rather than an exception."""
    result = db.run_no_throw("this is a error sql")
    assert result.startswith("Error:")
|
||||
|
||||
|
||||
def test_get_index_empty(db):
    """The bare `test` table has no indexes."""
    db.run("create table if not exists test(id int)")
    assert db.get_indexes("test") == []
|
Loading…
Reference in New Issue
Block a user