Merge pull request #15 from csunny/dev

Add local vector database
This commit is contained in:
magic.chen
2023-05-07 15:41:56 +08:00
committed by GitHub
22 changed files with 48171 additions and 311 deletions

4
.gitignore vendored
View File

@@ -130,4 +130,6 @@ dmypy.json
# Pyre type checker
.pyre/
.DS_Store
logs
logs
.vectordb

View File

@@ -1,5 +1,7 @@
# DB-GPT
A Open Database-GPT Experiment
A Open Database-GPT Experiment, A fully localized project.
一个数据库相关的GPT实验项目, 模型与数据全部本地化部署, 绝对保障数据的隐私安全。 同时此GPT项目可以直接本地部署连接到私有数据库, 进行私有数据处理。
![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
@@ -52,3 +54,7 @@ python webserver.py
- SQL-diagnosis
总的来说它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)

View File

@@ -3,12 +3,16 @@
import torch
import os
import nltk
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LLM_MODEL_CONFIG = {
@@ -18,6 +22,7 @@ LLM_MODEL_CONFIG = {
}
VECTOR_SEARCH_TOP_K = 5
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048

View File

@@ -146,6 +146,16 @@ conv_vicuna_v1 = Conversation(
sep2="</s>",
)
conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
已知内容:
{context}
问题:
{question}
"""
default_conversation = conv_one_shot
conv_templates = {

View File

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,185 +0,0 @@
执行计划是对一条 SQL 查询语句在数据库中执行过程的描述。用户可以通过 EXPLAIN 命令查看优化器针对指定 SQL 生成的逻辑执行计划。
如果要分析某条 SQL 的性能问题,通常需要先查看 SQL 的执行计划,排查每一步 SQL 执行是否存在问题。所以读懂执行计划是 SQL 优化的先决条件,而了解执行计划的算子是理解 EXPLAIN 命令的关键。
OceanBase 数据库的执行计划命令有三种模式EXPLAIN BASIC、EXPLAIN 和 EXPLAIN EXTENDED。这三种模式对执行计划展现不同粒度的细节信息:
EXPLAIN BASIC 命令用于最基本的计划展示。
EXPLAIN EXTENDED 命令用于最详细的计划展示(通常在排查问题时使用这种展示模式)。
EXPLAIN 命令所展示的信息可以帮助普通用户了解整个计划的执行方式。
EXPLAIN 命令格式如下:
EXPLAIN [BASIC | EXTENDED | PARTITIONS | FORMAT = format_name] [PRETTY | PRETTY_COLOR] explainable_stmt
format_name:
{ TRADITIONAL | JSON }
explainable_stmt:
{ SELECT st
| DELETE statement
| INSERT statement
| REPLACE statement
| UPDATE statement }
EXPLAIN 命令适用于 SELECT、DELETE、INSERT、REPLACE 和 UPDATE 语句,显示优化器所提供的有关语句执行计划的信息,包括如何处理该语句,如何联接表以及以何种顺序联接表等信息。
一般来说,可以使用 EXPLAIN EXTENDED 命令,将表扫描的范围段展示出来。使用 EXPLAIN OUTLINE 命令可以显示 Outline 信息。
FORMAT 选项可用于选择输出格式。TRADITIONAL 表示以表格格式显示输出这也是默认设置。JSON 表示以 JSON 格式显示信息。
使用 EXPLAIN PARTITITIONS 也可用于检查涉及分区表的查询。如果检查针对非分区表的查询,则不会产生错误,但 PARTIONS 列的值始终为 NULL。
对于复杂的执行计划,可以使用 PRETTY 或者 PRETTY_COLOR 选项将计划树中的父节点和子节点使用树线或彩色树线连接起来,使得执行计划展示更方便阅读。示例如下:
obclient> CREATE TABLE p1table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 2;
Query OK, 0 rows affected
obclient> CREATE TABLE p2table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 4;
Query OK, 0 rows affected
obclient> EXPLAIN EXTENDED PRETTY_COLOR SELECT * FROM p1table p1 JOIN p2table p2 ON p1.c1=p2.c2\G
*************************** 1. row ***************************
Query Plan: ==========================================================
|ID|OPERATOR |NAME |EST. ROWS|COST|
----------------------------------------------------------
|0 |PX COORDINATOR | |1 |278 |
|1 | EXCHANGE OUT DISTR |:EX10001|1 |277 |
|2 | HASH JOIN | |1 |276 |
|3 | ├PX PARTITION ITERATOR | |1 |92 |
|4 | │ TABLE SCAN |P1 |1 |92 |
|5 | └EXCHANGE IN DISTR | |1 |184 |
|6 | EXCHANGE OUT DISTR (PKEY)|:EX10000|1 |184 |
|7 | PX PARTITION ITERATOR | |1 |183 |
|8 | TABLE SCAN |P2 |1 |183 |
==========================================================
Outputs & filters:
-------------------------------------
0 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil)
1 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil), dop=1
2 - output([P1.C1], [P2.C2], [P1.C2], [P2.C1]), filter(nil),
equal_conds([P1.C1 = P2.C2]), other_conds(nil)
3 - output([P1.C1], [P1.C2]), filter(nil)
4 - output([P1.C1], [P1.C2]), filter(nil),
access([P1.C1], [P1.C2]), partitions(p[0-1])
5 - output([P2.C2], [P2.C1]), filter(nil)
6 - (#keys=1, [P2.C2]), output([P2.C2], [P2.C1]), filter(nil), dop=1
7 - output([P2.C1], [P2.C2]), filter(nil)
8 - output([P2.C1], [P2.C2]), filter(nil),
access([P2.C1], [P2.C2]), partitions(p[0-3])
1 row in set
## 执行计划形状与算子信息
在数据库系统中,执行计划在内部通常是以树的形式来表示的,但是不同的数据库会选择不同的方式展示给用户。
如下示例分别为 PostgreSQL 数据库、Oracle 数据库和 OceanBase 数据库对于 TPCDS Q3 的计划展示。
```sql
obclient> SELECT /*TPC-DS Q3*/ *
FROM (SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
Sum(ss_net_profit) sum_agg
FROM date_dim dt,
store_sales,
item
WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
AND store_sales.ss_item_sk = item.i_item_sk
AND item.i_manufact_id = 914
AND dt.d_moy = 11
GROUP BY dt.d_year,
item.i_brand,
item.i_brand_id
ORDER BY dt.d_year,
sum_agg DESC,
brand_id)
WHERE ROWNUM <= 100;
PostgreSQL 数据库执行计划展示如下:
Limit (cost=13986.86..13987.20 rows=27 width=91)
Sort (cost=13986.86..13986.93 rows=27 width=65)
Sort Key: dt.d_year, (sum(store_sales.ss_net_profit)), item.i_brand_id
HashAggregate (cost=13985.95..13986.22 rows=27 width=65)
Merge Join (cost=13884.21..13983.91 rows=204 width=65)
Merge Cond: (dt.d_date_sk = store_sales.ss_sold_date_sk)
Index Scan using date_dim_pkey on date_dim dt (cost=0.00..3494.62 rows=6080 width=8)
Filter: (d_moy = 11)
Sort (cost=12170.87..12177.27 rows=2560 width=65)
Sort Key: store_sales.ss_sold_date_sk
Nested Loop (cost=6.02..12025.94 rows=2560 width=65)
Seq Scan on item (cost=0.00..1455.00 rows=16 width=59)
Filter: (i_manufact_id = 914)
Bitmap Heap Scan on store_sales (cost=6.02..658.94 rows=174 width=14)
Recheck Cond: (ss_item_sk = item.i_item_sk)
Bitmap Index Scan on store_sales_pkey (cost=0.00..5.97 rows=174 width=0)
Index Cond: (ss_item_sk = item.i_item_sk)
Oracle 数据库执行计划展示如下:
Plan hash value: 2331821367
--------------------------------------------------------------------------------------------------
| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time |
--------------------------------------------------------------------------------------------------
| 0 | SELECT STATEMENT | | 100 | 9100 | 3688 (1)| 00:00:01 |
|* 1 | COUNT STOPKEY | | | | | |
| 2 | VIEW | | 2736 | 243K| 3688 (1)| 00:00:01 |
|* 3 | SORT ORDER BY STOPKEY | | 2736 | 256K| 3688 (1)| 00:00:01 |
| 4 | HASH GROUP BY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|* 5 | HASH JOIN | | 2736 | 256K| 3686 (1)| 00:00:01 |
|* 6 | TABLE ACCESS FULL | DATE_DIM | 6087 | 79131 | 376 (1)| 00:00:01 |
| 7 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
| 8 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|* 9 | TABLE ACCESS FULL | ITEM | 18 | 1188 | 375 (0)| 00:00:01 |
|* 10 | INDEX RANGE SCAN | SYS_C0010069 | 159 | | 2 (0)| 00:00:01 |
| 11 | TABLE ACCESS BY INDEX ROWID| STORE_SALES | 159 | 2703 | 163 (0)| 00:00:01 |
--------------------------------------------------------------------------------------------------
OceanBase 数据库执行计划展示如下:
|ID|OPERATOR |NAME |EST. ROWS|COST |
-------------------------------------------------------
|0 |LIMIT | |100 |81141|
|1 | TOP-N SORT | |100 |81127|
|2 | HASH GROUP BY | |2924 |68551|
|3 | HASH JOIN | |2924 |65004|
|4 | SUBPLAN SCAN |VIEW1 |2953 |19070|
|5 | HASH GROUP BY | |2953 |18662|
|6 | NESTED-LOOP JOIN| |2953 |15080|
|7 | TABLE SCAN |ITEM |19 |11841|
|8 | TABLE SCAN |STORE_SALES|161 |73 |
|9 | TABLE SCAN |DT |6088 |29401|
=======================================================
由示例可见OceanBase 数据库的计划展示与 Oracle 数据库类似。
OceanBase 数据库执行计划中的各列的含义如下:
列名 含义
ID 执行树按照前序遍历的方式得到的编号(从 0 开始)。
OPERATOR 操作算子的名称。
NAME 对应表操作的表名(索引名)。
EST. ROWS 估算该操作算子的输出行数。
COST 该操作算子的执行代价(微秒)。
OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形结构展示。其中每一个操作在树中的层次通过其在 operator 中的缩进予以展示,层次最深的优先执行,层次相同的以特定算子的执行顺序为标准来执行。
问题: update a not exists (b…)
我一开始以为 B是驱动表B的数据挺多的 后来看到NLAJ是说左边的表关联右边的表
所以这个的驱动表是不是实际是A用A的匹配B的这个理解有问题吗
回答: 没错 A 驱动 B的
问题: 光知道最下最右的是驱动表了 所以一开始搞得有点懵 :sweat_smile:
回答: nlj应该原理应该都是左表(驱动表)的记录探测右表(被驱动表) 选哪张成为左表或右表就基于一些其他考量了,比如数据量, 而anti join/semi join只是对 not exist/exist的一种优化相关的原理和资料网上可以查阅一下
问题: 也就是nlj 就是按照之前理解的谁先执行 谁就是驱动表 也就是执行计划中的最右的表
而anti join/semi join谁在not exist左面谁就是驱动表。这么理解对吧
回答: nlj也是左表的表是驱动表这个要了解下计划执行方面的基本原理取左表的一行数据再遍历右表一旦满足连接条件就可以返回数据
anti/semi只是因为not exists/exist的语义只是返回左表数据改成anti join是一种计划优化连接的方式比子查询更优

View File

@@ -3,6 +3,71 @@
import torch
@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
context_len=2048, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if device == "mps":
# Switch to CPU by avoiding some bugs in mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
output = tokenizer.decode(output_ids, skip_special_tokens=True)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
stopped = True
yield output
if stopped:
break
del past_key_values
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
prompt = params["prompt"]

View File

@@ -5,6 +5,7 @@ import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModel
)
from fastchat.serve.compression import compress_module
@@ -23,20 +24,39 @@ class ModerLoader:
"device_map": "auto",
}
def loader(self, load_8bit=False, debug=False):
tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)
def loader(self, num_gpus, load_8bit=False, debug=False):
if self.device == "cpu":
kwargs = {}
elif self.device == "cuda":
kwargs = {"torch_dtype": torch.float16}
if num_gpus == "auto":
kwargs["device_map"] = "auto"
else:
num_gpus = int(num_gpus)
if num_gpus != 1:
kwargs.update({
"device_map": "auto",
"max_memory": {i: "13GiB" for i in range(num_gpus)},
})
else:
raise ValueError(f"Invalid device: {self.device}")
if "chatglm" in self.model_path:
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
else:
tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(self.model_path,
low_cpu_mem_usage=True, **kwargs)
if load_8bit:
compress_module(model, self.device)
if (self.device == "cuda" and num_gpus == 1):
model.to(self.device)
if debug:
print(model)
if load_8bit:
compress_module(model, self.device)
# if self.device == "cuda":
# model.to(self.device)
return model, tokenizer

View File

@@ -1,56 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"
def generate(prompt):
params = {
"model": "vicuna-13b",
"prompt": prompt,
"temperature": 0.7,
"max_new_tokens": 512,
"stop": "###"
}
sts_response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_status_path)
)
print(sts_response.text)
response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"]
yield(output)
time.sleep(0.02)
if __name__ == "__main__":
print(LLM_MODEL)
with gr.Blocks() as demo:
gr.Markdown("数据库SQL生成助手")
with gr.Tab("SQL生成"):
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button("提交")
text_button.click(generate, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

View File

@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"
def generate(prompt):
params = {
"model": "vicuna-13b",
"prompt": prompt,
"temperature": 0.7,
"max_new_tokens": 512,
"stop": "###"
}
sts_response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_status_path)
)
print(sts_response.text)
response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"]
yield(output)
time.sleep(0.02)
if __name__ == "__main__":
print(LLM_MODEL)
with gr.Blocks() as demo:
gr.Markdown("数据库SQL生成助手")
with gr.Tab("SQL生成"):
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button("提交")
text_button.click(generate, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
code_highlight_css = (
"""
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
#chatbot .k { color: #008000; font-weight: bold }
#chatbot .o { color: #666666 }
#chatbot .ch { color: #408080; font-style: italic }
#chatbot .cm { color: #408080; font-style: italic }
#chatbot .cp { color: #BC7A00 }
#chatbot .cpf { color: #408080; font-style: italic }
#chatbot .c1 { color: #408080; font-style: italic }
#chatbot .cs { color: #408080; font-style: italic }
#chatbot .gd { color: #A00000 }
#chatbot .ge { font-style: italic }
#chatbot .gr { color: #FF0000 }
#chatbot .gh { color: #000080; font-weight: bold }
#chatbot .gi { color: #00A000 }
#chatbot .go { color: #888888 }
#chatbot .gp { color: #000080; font-weight: bold }
#chatbot .gs { font-weight: bold }
#chatbot .gu { color: #800080; font-weight: bold }
#chatbot .gt { color: #0044DD }
#chatbot .kc { color: #008000; font-weight: bold }
#chatbot .kd { color: #008000; font-weight: bold }
#chatbot .kn { color: #008000; font-weight: bold }
#chatbot .kp { color: #008000 }
#chatbot .kr { color: #008000; font-weight: bold }
#chatbot .kt { color: #B00040 }
#chatbot .m { color: #666666 }
#chatbot .s { color: #BA2121 }
#chatbot .na { color: #7D9029 }
#chatbot .nb { color: #008000 }
#chatbot .nc { color: #0000FF; font-weight: bold }
#chatbot .no { color: #880000 }
#chatbot .nd { color: #AA22FF }
#chatbot .ni { color: #999999; font-weight: bold }
#chatbot .ne { color: #D2413A; font-weight: bold }
#chatbot .nf { color: #0000FF }
#chatbot .nl { color: #A0A000 }
#chatbot .nn { color: #0000FF; font-weight: bold }
#chatbot .nt { color: #008000; font-weight: bold }
#chatbot .nv { color: #19177C }
#chatbot .ow { color: #AA22FF; font-weight: bold }
#chatbot .w { color: #bbbbbb }
#chatbot .mb { color: #666666 }
#chatbot .mf { color: #666666 }
#chatbot .mh { color: #666666 }
#chatbot .mi { color: #666666 }
#chatbot .mo { color: #666666 }
#chatbot .sa { color: #BA2121 }
#chatbot .sb { color: #BA2121 }
#chatbot .sc { color: #BA2121 }
#chatbot .dl { color: #BA2121 }
#chatbot .sd { color: #BA2121; font-style: italic }
#chatbot .s2 { color: #BA2121 }
#chatbot .se { color: #BB6622; font-weight: bold }
#chatbot .sh { color: #BA2121 }
#chatbot .si { color: #BB6688; font-weight: bold }
#chatbot .sx { color: #008000 }
#chatbot .sr { color: #BB6688 }
#chatbot .s1 { color: #BA2121 }
#chatbot .ss { color: #19177C }
#chatbot .bp { color: #008000 }
#chatbot .fm { color: #0000FF }
#chatbot .vc { color: #19177C }
#chatbot .vg { color: #19177C }
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
""")
#.highlight { background: #f8f8f8; }

View File

@@ -0,0 +1,167 @@
"""
Fork from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/gradio_patch.py
"""
from __future__ import annotations
from gradio.components import *
from markdown2 import Markdown
class _Keywords(Enum):
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
@document("style")
class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
"""
Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
Preprocessing: this component does *not* accept input.
Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
Demos: chatbot_simple, chatbot_multimodal
"""
def __init__(
self,
value: List[Tuple[str | None, str | None]] | Callable | None = None,
color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
*,
label: str | None = None,
every: float | None = None,
show_label: bool = True,
visible: bool = True,
elem_id: str | None = None,
elem_classes: List[str] | str | None = None,
**kwargs,
):
"""
Parameters:
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
label: component name in interface.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
show_label: if True, will display label.
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
if color_map is not None:
warnings.warn(
"The 'color_map' parameter has been deprecated.",
)
#self.md = utils.get_markdown_parser()
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
self.select: EventListenerMethod
"""
Event listener for when the user selects message from Chatbot.
Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
See EventData documentation on how to use this event data.
"""
IOComponent.__init__(
self,
label=label,
every=every,
show_label=show_label,
visible=visible,
elem_id=elem_id,
elem_classes=elem_classes,
value=value,
**kwargs,
)
def get_config(self):
return {
"value": self.value,
"selectable": self.selectable,
**IOComponent.get_config(self),
}
@staticmethod
def update(
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
label: str | None = None,
show_label: bool | None = None,
visible: bool | None = None,
):
updated_config = {
"label": label,
"show_label": show_label,
"visible": visible,
"value": value,
"__type__": "update",
}
return updated_config
def _process_chat_messages(
self, chat_message: str | Tuple | List | Dict | None
) -> str | Dict | None:
if chat_message is None:
return None
elif isinstance(chat_message, (tuple, list)):
mime_type = processing_utils.get_mimetype(chat_message[0])
return {
"name": chat_message[0],
"mime_type": mime_type,
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
"data": None, # These last two fields are filled in by the frontend
"is_file": True,
}
elif isinstance(
chat_message, dict
): # This happens for previously processed messages
return chat_message
elif isinstance(chat_message, str):
#return self.md.render(chat_message)
return str(self.md.convert(chat_message))
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
def postprocess(
self,
y: List[
Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
],
) -> List[Tuple[str | Dict | None, str | Dict | None]]:
"""
Parameters:
y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
Returns:
List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
"""
if y is None:
return []
processed_messages = []
for message_pair in y:
assert isinstance(
message_pair, (tuple, list)
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
assert (
len(message_pair) == 2
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
processed_messages.append(
(
#self._process_chat_messages(message_pair[0]),
'<pre style="font-family: var(--font)">' +
message_pair[0] + "</pre>",
self._process_chat_messages(message_pair[1]),
)
)
return processed_messages
def style(self, height: int | None = None, **kwargs):
"""
This method can be used to change the appearance of the Chatbot component.
"""
if height is not None:
self._style["height"] = height
if kwargs.get("color_map") is not None:
warnings.warn("The 'color_map' parameter has been deprecated.")
Component.style(
self,
**kwargs,
)
return self

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env python3
#-*- coding: utf-8 -*-
import json
import torch
import gradio as gr
from fastchat.serve.inference import generate_stream
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODE,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map="auto",
)
def generate(prompt):
model.to(device)
print(model, tokenizer)
params = {
"model": "vicuna-13b",
"prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
"temperature": 0.7,
"max_new_tokens": 512,
"stop": "###"
}
output = generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2)
for chunk in output:
yield chunk
if __name__ == "__main__":
with gr.Blocks() as demo:
gr.Markdown("数据库SQL生成助手")
with gr.Tab("SQL生成"):
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button("提交")
text_button.click(generate, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qk_prompt_template
from langchain.chains import RetrievalQA
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
class KnownLedgeBaseQA:
llm: object = None
def __init__(self) -> None:
k2v = KnownLedge2Vector()
self.vector_store = k2v.init_vector_store()
def get_answer(self, query):
prompt_template = conv_qk_prompt_template
prompt = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
knownledge_chain = RetrievalQA.from_llm(
llm=self.llm,
retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
prompt=prompt
)
knownledge_chain.return_source_documents = True
result = knownledge_chain({"query": query})
yield result

View File

@@ -7,10 +7,12 @@ import json
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastchat.serve.inference import generate_stream
from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
@@ -20,9 +22,9 @@ model_path = LLM_MODEL_CONFIG[LLM_MODEL]
global_counter = 0
model_semaphore = None
# ml = ModerLoader(model_path=model_path)
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
class ModelWorker:
def __init__(self):

View File

@@ -29,8 +29,8 @@ from fastchat.utils import (
moderation_msg
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -281,7 +281,7 @@ def build_single_model_ui():
"""
learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
"""
state = gr.State()

View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import copy
from typing import Optional, List, Dict
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
from langchain.chains import VectorDBQA
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
class KnownLedge2Vector:
embeddings: object = None
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
top_k: int = VECTOR_SEARCH_TOP_K
def __init__(self, model_name=None) -> None:
if not model_name:
# use default embedding model
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
def init_vector_store(self):
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print("向量数据库持久化地址: ", persist_dir)
if os.path.exists(persist_dir):
# 从本地持久化文件中Load
print("从本地向量加载数据...")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
# vector_store.add_documents(documents=documents)
else:
documents = self.load_knownlege()
# 重新初始化
vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir)
vector_store.persist()
return vector_store
def load_knownlege(self):
docments = []
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
for file in files:
filename = os.path.join(root, file)
docs = self._load_file(filename)
# 更新metadata数据
new_docs = []
for doc in docs:
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
print("文档2向量初始化中, 请稍等...", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
def _load_file(self, filename):
# 加载文件
if filename.lower().endswith(".pdf"):
loader = UnstructuredFileLoader(filename)
text_splitor = CharacterTextSplitter()
docs = loader.load_and_split(text_splitor)
else:
loader = UnstructuredFileLoader(filename, mode="elements")
text_splitor = CharacterTextSplitter()
docs = loader.load_and_split(text_splitor)
return docs
def _load_from_url(self, url):
"""Load data from url address"""
pass
def query(self, q):
"""Query similar doc from Vector """
vector_store = self.init_vector_store()
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
for doc in docs:
dc, s = doc
yield s, dc
if __name__ == "__main__":
k2v = KnownLedge2Vector()
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print(persist_dir)
for s, dc in k2v.query("什么是OceanBase"):
print(s, dc.page_content, dc.metadata)