Merge pull request #9 from csunny/dev

connect database
This commit is contained in:
magic.chen 2023-05-03 22:18:36 +08:00 committed by GitHub
commit f50a4b05a6
18 changed files with 395 additions and 51 deletions

View File

@ -3,6 +3,8 @@ A Open Database-GPT Experiment
![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
DB-GPT 是一个实验性的开源应用程序它基于FastChat并使用vicuna-13b作为基础模型。此外此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。
DB-GPT is an experimental open-source application that builds upon the fastchat model and uses vicuna as its base model. Additionally, it looks like this application incorporates langchain and llama-index embedding knowledge to improve Database-QA capabilities.
@ -14,21 +16,26 @@ Run on an RTX 4090 GPU (The origin mov not sped up!, [YouTube地址](https://www
![](https://github.com/csunny/DB-GPT/blob/dev/asserts/演示.gif)
- SQL生成示例
首先选择对应的数据库, 然后模型即可根据对应的数据库Schema信息生成SQL
<img src="https://github.com/csunny/DB-GPT/blob/dev/asserts/sql_generate.png" width="600" margin-left="auto" margin-right="auto" >
<img src="https://github.com/csunny/DB-GPT/blob/dev/asserts/SQLGEN.png" width="600" margin-left="auto" margin-right="auto" >
- 数据库QA示例
<img src="https://github.com/csunny/DB-GPT/blob/dev/asserts/DB_QA.png" margin-left="auto" margin-right="auto" width="600">
# Install
1. Run model server
1. 基础模型下载
关于基础模型, 可以根据[vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights)合成教程进行合成。
如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代。 替代模型: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)
2. Run model server
```
cd pilot/server
python vicuna_server.py
```
2. Run gradio webui
3. Run gradio webui
```
python webserver.py
```
@ -37,3 +44,5 @@ python webserver.py
- SQL-Generate
- Database-QA Based Knowledge
- SQL-diagnosis
总的来说它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。

BIN
asserts/SQLGEN.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 297 KiB

View File

@ -33,12 +33,22 @@ def knowledged_qa_demo(text_list):
def get_answer(q):
base_knowledge = """ 这是一段测试文字 """
base_knowledge = """ """
text_list = [base_knowledge]
index = knowledged_qa_demo(text_list)
response = index.query(q)
return response.response
def get_similar(q):
from pilot.vector_store.extract_tovec import knownledge_tovec
docsearch = knownledge_tovec("./datasets/plan.md")
docs = docsearch.similarity_search_with_score(q, k=1)
for doc in docs:
dc, s = doc
print(dc.page_content)
yield dc.page_content
if __name__ == "__main__":
# agent_demo()
@ -49,7 +59,7 @@ if __name__ == "__main__":
text_output = gr.TextArea()
text_button = gr.Button()
text_button.click(get_answer, inputs=text_input, outputs=text_output)
text_button.click(get_similar, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

0
pilot/chain/__init__.py Normal file
View File

0
pilot/client/__init__.py Normal file
View File

0
pilot/common/__init__.py Normal file
View File

View File

@ -25,3 +25,11 @@ vicuna_model_server = "http://192.168.31.114:8000"
# Load model config
isload_8bit = True
isdebug = False
DB_SETTINGS = {
"user": "root",
"password": "********",
"host": "localhost",
"port": 3306
}

View File

@ -1,2 +1,42 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import pymysql
class MySQLOperator:
"""Connect MySQL Database fetch MetaData For LLM Prompt """
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
def __init__(self, user, password, host="localhost", port=3306) -> None:
self.conn = pymysql.connect(
host=host,
user=user,
passwd=password,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor
)
def get_schema(self, schema_name):
with self.conn.cursor() as cursor:
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
"""
cursor.execute(_sql)
results = cursor.fetchall()
return results
def get_db_list(self):
with self.conn.cursor() as cursor:
_sql = """
show databases;
"""
cursor.execute(_sql)
results = cursor.fetchall()
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
return dbs

View File

@ -4,7 +4,7 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any
from pilot.configs.model_config import DB_SETTINGS
class SeparatorStyle(Enum):
@ -88,6 +88,19 @@ class Conversation:
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql_conn import MySQLOperator
mo = MySQLOperator(
**DB_SETTINGS
)
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
conv_one_shot = Conversation(
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
@ -121,7 +134,7 @@ conv_one_shot = Conversation(
sep_style=SeparatorStyle.SINGLE,
sep="###"
)
conv_vicuna_v1 = Conversation(
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
@ -137,5 +150,10 @@ default_conversation = conv_one_shot
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1
"vicuna_v1": conv_vicuna_v1,
}
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)

0
pilot/data/__init__.py Normal file
View File

View File

185
pilot/datasets/plan.md Normal file
View File

@ -0,0 +1,185 @@
执行计划是对一条 SQL 查询语句在数据库中执行过程的描述。用户可以通过 EXPLAIN 命令查看优化器针对指定 SQL 生成的逻辑执行计划。
如果要分析某条 SQL 的性能问题,通常需要先查看 SQL 的执行计划,排查每一步 SQL 执行是否存在问题。所以读懂执行计划是 SQL 优化的先决条件,而了解执行计划的算子是理解 EXPLAIN 命令的关键。
OceanBase 数据库的执行计划命令有三种模式EXPLAIN BASIC、EXPLAIN 和 EXPLAIN EXTENDED。这三种模式对执行计划展现不同粒度的细节信息:
EXPLAIN BASIC 命令用于最基本的计划展示。
EXPLAIN EXTENDED 命令用于最详细的计划展示(通常在排查问题时使用这种展示模式)。
EXPLAIN 命令所展示的信息可以帮助普通用户了解整个计划的执行方式。
EXPLAIN 命令格式如下:
EXPLAIN [BASIC | EXTENDED | PARTITIONS | FORMAT = format_name] [PRETTY | PRETTY_COLOR] explainable_stmt
format_name:
{ TRADITIONAL | JSON }
explainable_stmt:
{ SELECT st
| DELETE statement
| INSERT statement
| REPLACE statement
| UPDATE statement }
EXPLAIN 命令适用于 SELECT、DELETE、INSERT、REPLACE 和 UPDATE 语句,显示优化器所提供的有关语句执行计划的信息,包括如何处理该语句,如何联接表以及以何种顺序联接表等信息。
一般来说,可以使用 EXPLAIN EXTENDED 命令,将表扫描的范围段展示出来。使用 EXPLAIN OUTLINE 命令可以显示 Outline 信息。
FORMAT 选项可用于选择输出格式。TRADITIONAL 表示以表格格式显示输出这也是默认设置。JSON 表示以 JSON 格式显示信息。
使用 EXPLAIN PARTITITIONS 也可用于检查涉及分区表的查询。如果检查针对非分区表的查询,则不会产生错误,但 PARTIONS 列的值始终为 NULL。
对于复杂的执行计划,可以使用 PRETTY 或者 PRETTY_COLOR 选项将计划树中的父节点和子节点使用树线或彩色树线连接起来,使得执行计划展示更方便阅读。示例如下:
obclient> CREATE TABLE p1table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 2;
Query OK, 0 rows affected
obclient> CREATE TABLE p2table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 4;
Query OK, 0 rows affected
obclient> EXPLAIN EXTENDED PRETTY_COLOR SELECT * FROM p1table p1 JOIN p2table p2 ON p1.c1=p2.c2\G
*************************** 1. row ***************************
Query Plan: ==========================================================
|ID|OPERATOR |NAME |EST. ROWS|COST|
----------------------------------------------------------
|0 |PX COORDINATOR | |1 |278 |
|1 | EXCHANGE OUT DISTR |:EX10001|1 |277 |
|2 | HASH JOIN | |1 |276 |
|3 | ├PX PARTITION ITERATOR | |1 |92 |
|4 | │ TABLE SCAN |P1 |1 |92 |
|5 | └EXCHANGE IN DISTR | |1 |184 |
|6 | EXCHANGE OUT DISTR (PKEY)|:EX10000|1 |184 |
|7 | PX PARTITION ITERATOR | |1 |183 |
|8 | TABLE SCAN |P2 |1 |183 |
==========================================================
Outputs & filters:
-------------------------------------
0 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil)
1 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil), dop=1
2 - output([P1.C1], [P2.C2], [P1.C2], [P2.C1]), filter(nil),
equal_conds([P1.C1 = P2.C2]), other_conds(nil)
3 - output([P1.C1], [P1.C2]), filter(nil)
4 - output([P1.C1], [P1.C2]), filter(nil),
access([P1.C1], [P1.C2]), partitions(p[0-1])
5 - output([P2.C2], [P2.C1]), filter(nil)
6 - (#keys=1, [P2.C2]), output([P2.C2], [P2.C1]), filter(nil), dop=1
7 - output([P2.C1], [P2.C2]), filter(nil)
8 - output([P2.C1], [P2.C2]), filter(nil),
access([P2.C1], [P2.C2]), partitions(p[0-3])
1 row in set
## 执行计划形状与算子信息
在数据库系统中,执行计划在内部通常是以树的形式来表示的,但是不同的数据库会选择不同的方式展示给用户。
如下示例分别为 PostgreSQL 数据库、Oracle 数据库和 OceanBase 数据库对于 TPCDS Q3 的计划展示。
```sql
obclient> SELECT /*TPC-DS Q3*/ *
FROM (SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
Sum(ss_net_profit) sum_agg
FROM date_dim dt,
store_sales,
item
WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
AND store_sales.ss_item_sk = item.i_item_sk
AND item.i_manufact_id = 914
AND dt.d_moy = 11
GROUP BY dt.d_year,
item.i_brand,
item.i_brand_id
ORDER BY dt.d_year,
sum_agg DESC,
brand_id)
WHERE ROWNUM <= 100;
PostgreSQL 数据库执行计划展示如下:
Limit (cost=13986.86..13987.20 rows=27 width=91)
Sort (cost=13986.86..13986.93 rows=27 width=65)
Sort Key: dt.d_year, (sum(store_sales.ss_net_profit)), item.i_brand_id
HashAggregate (cost=13985.95..13986.22 rows=27 width=65)
Merge Join (cost=13884.21..13983.91 rows=204 width=65)
Merge Cond: (dt.d_date_sk = store_sales.ss_sold_date_sk)
Index Scan using date_dim_pkey on date_dim dt (cost=0.00..3494.62 rows=6080 width=8)
Filter: (d_moy = 11)
Sort (cost=12170.87..12177.27 rows=2560 width=65)
Sort Key: store_sales.ss_sold_date_sk
Nested Loop (cost=6.02..12025.94 rows=2560 width=65)
Seq Scan on item (cost=0.00..1455.00 rows=16 width=59)
Filter: (i_manufact_id = 914)
Bitmap Heap Scan on store_sales (cost=6.02..658.94 rows=174 width=14)
Recheck Cond: (ss_item_sk = item.i_item_sk)
Bitmap Index Scan on store_sales_pkey (cost=0.00..5.97 rows=174 width=0)
Index Cond: (ss_item_sk = item.i_item_sk)
Oracle 数据库执行计划展示如下:
Plan hash value: 2331821367
--------------------------------------------------------------------------------------------------
| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time |
--------------------------------------------------------------------------------------------------
| 0 | SELECT STATEMENT | | 100 | 9100 | 3688 (1)| 00:00:01 |
|* 1 | COUNT STOPKEY | | | | | |
| 2 | VIEW | | 2736 | 243K| 3688 (1)| 00:00:01 |
|* 3 | SORT ORDER BY STOPKEY | | 2736 | 256K| 3688 (1)| 00:00:01 |
| 4 | HASH GROUP BY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|* 5 | HASH JOIN | | 2736 | 256K| 3686 (1)| 00:00:01 |
|* 6 | TABLE ACCESS FULL | DATE_DIM | 6087 | 79131 | 376 (1)| 00:00:01 |
| 7 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
| 8 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|* 9 | TABLE ACCESS FULL | ITEM | 18 | 1188 | 375 (0)| 00:00:01 |
|* 10 | INDEX RANGE SCAN | SYS_C0010069 | 159 | | 2 (0)| 00:00:01 |
| 11 | TABLE ACCESS BY INDEX ROWID| STORE_SALES | 159 | 2703 | 163 (0)| 00:00:01 |
--------------------------------------------------------------------------------------------------
OceanBase 数据库执行计划展示如下:
|ID|OPERATOR |NAME |EST. ROWS|COST |
-------------------------------------------------------
|0 |LIMIT | |100 |81141|
|1 | TOP-N SORT | |100 |81127|
|2 | HASH GROUP BY | |2924 |68551|
|3 | HASH JOIN | |2924 |65004|
|4 | SUBPLAN SCAN |VIEW1 |2953 |19070|
|5 | HASH GROUP BY | |2953 |18662|
|6 | NESTED-LOOP JOIN| |2953 |15080|
|7 | TABLE SCAN |ITEM |19 |11841|
|8 | TABLE SCAN |STORE_SALES|161 |73 |
|9 | TABLE SCAN |DT |6088 |29401|
=======================================================
由示例可见OceanBase 数据库的计划展示与 Oracle 数据库类似。
OceanBase 数据库执行计划中的各列的含义如下:
列名 含义
ID 执行树按照前序遍历的方式得到的编号(从 0 开始)。
OPERATOR 操作算子的名称。
NAME 对应表操作的表名(索引名)。
EST. ROWS 估算该操作算子的输出行数。
COST 该操作算子的执行代价(微秒)。
OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形结构展示。其中每一个操作在树中的层次通过其在 operator 中的缩进予以展示,层次最深的优先执行,层次相同的以特定算子的执行顺序为标准来执行。
问题: update a not exists (b…)
我一开始以为 B是驱动表B的数据挺多的 后来看到NLAJ是说左边的表关联右边的表
所以这个的驱动表是不是实际是A用A的匹配B的这个理解有问题吗
回答: 没错 A 驱动 B的
问题: 光知道最下最右的是驱动表了 所以一开始搞得有点懵 :sweat_smile:
回答: nlj应该原理应该都是左表(驱动表)的记录探测右表(被驱动表) 选哪张成为左表或右表就基于一些其他考量了,比如数据量, 而anti join/semi join只是对 not exist/exist的一种优化相关的原理和资料网上可以查阅一下
问题: 也就是nlj 就是按照之前理解的谁先执行 谁就是驱动表 也就是执行计划中的最右的表
而anti join/semi join谁在not exist左面谁就是驱动表。这么理解对吧
回答: nlj也是左表的表是驱动表这个要了解下计划执行方面的基本原理取左表的一行数据再遍历右表一旦满足连接条件就可以返回数据
anti/semi只是因为not exists/exist的语义只是返回左表数据改成anti join是一种计划优化连接的方式比子查询更优

View File

@ -4,10 +4,10 @@
import torch
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048):
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256))
max_new_tokens = int(params.get("max_new_tokens", 1024))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
@ -21,29 +21,29 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
else:
raise TypeError("Stop parameter must be string or list of strings.")
pos = -1
input_ids = tokenizer(prompt).input_ids
output_ids = []
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
stop_word = None
for i in range(max_new_tokens):
if i == 0:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
@ -57,15 +57,22 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
else:
stopped = False
output = tokenizer.decode(output_ids, skip_special_tokens=True)
# print("Partial output:", output)
for stop_str in stop_strings:
# print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
pos = output.rfind(stop_str)
if pos != -1:
# print("Found stop str: ", output)
output = output[:pos]
# print("Trimmed output: ", output)
stopped = True
stop_word = stop_str
break
else:
pass
# print("Not found")
if stopped:
break
@ -73,7 +80,7 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
del past_key_values
if pos != -1:
return output[:pos]
return output
return output
@torch.inference_mode()

View File

@ -17,17 +17,25 @@ class VicunaRequestLLM(LLM):
if isinstance(stop, list):
stop = stop + ["Observation:"]
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
params = {
"prompt": prompt,
"temperature": 0,
"max_new_tokens": 256,
"temperature": 0.7,
"max_new_tokens": 1024,
"stop": stop
}
response = requests.post(
url=urljoin(vicuna_model_server, self.vicuna_generate_path),
data=json.dumps(params)
data=json.dumps(params),
)
response.raise_for_status()
# for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
# if chunk:
# data = json.loads(chunk.decode())
# if data["error_code"] == 0:
# output = data["text"][skip_echo_len:].strip()
# output = self.post_process_code(output)
# yield output
return response.json()["response"]
@property

View File

@ -115,6 +115,5 @@ def embeddings(prompt_request: EmbeddingRequest):
return {"response": [float(x) for x in output]}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", log_level="info")

View File

@ -10,6 +10,9 @@ import gradio as gr
import datetime
import requests
from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS
from pilot.connections.mysql_conn import MySQLOperator
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
@ -29,7 +32,7 @@ from fastchat.utils import (
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
logger = build_logger("webserver", "webserver.log")
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
@ -38,11 +41,28 @@ disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
dbs = []
priority = {
"vicuna-13b": "aaa"
}
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
)
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
@ -58,12 +78,10 @@ function() {
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown.update(
value=model, visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
state = default_conversation.copy()
return (state,
@ -120,26 +138,32 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
start_tstamp = time.time()
model_name = LLM_MODEL
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
# First round of conversation
# 第一轮对话需要加入提示Prompt
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
# prompt 中添加上下文提示
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
@ -226,7 +250,7 @@ def build_single_model_ui():
"""
state = gr.State()
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion("参数", open=False, visible=False) as parameter_row:
temperature = gr.Slider(
@ -247,29 +271,41 @@ def build_single_model_ui():
label="最大输出Token数",
)
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
with gr.Tabs():
with gr.TabItem("知识问答", elem_id="QA"):
pass
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="发送", visible=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="" "发送", visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False)
clear_btn = gr.Button(value="🗑️" "清理", interactive=False)
regenerate_btn = gr.Button(value="重新生成", interactive=False)
clear_btn = gr.Button(value="清理", interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -278,7 +314,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@ -286,7 +322,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
)
@ -343,6 +379,7 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info(f"args: {args}")
dbs = get_database_list()
logger.info(args)
demo = build_webdemo()
demo.queue(

View File

View File

@ -0,0 +1,23 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
embeddings = VicunaEmbeddingLLM()
def knownledge_tovec(filename):
with open(filename, "r") as f:
knownledge = f.read()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_text(knownledge)
docsearch = Chroma.from_texts(
texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]
)
return docsearch