mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-09 21:08:59 +00:00
2
.gitignore
vendored
2
.gitignore
vendored
@@ -131,3 +131,5 @@ dmypy.json
|
||||
.pyre/
|
||||
.DS_Store
|
||||
logs
|
||||
|
||||
.vectordb
|
@@ -1,5 +1,7 @@
|
||||
# DB-GPT
|
||||
A Open Database-GPT Experiment
|
||||
A Open Database-GPT Experiment, A fully localized project.
|
||||
|
||||
一个数据库相关的GPT实验项目, 模型与数据全部本地化部署, 绝对保障数据的隐私安全。 同时此GPT项目可以直接本地部署连接到私有数据库, 进行私有数据处理。
|
||||
|
||||

|
||||
|
||||
@@ -52,3 +54,7 @@ python webserver.py
|
||||
- SQL-diagnosis
|
||||
|
||||
总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
|
||||
|
||||
|
||||
# Licence
|
||||
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
|
@@ -3,12 +3,16 @@
|
||||
|
||||
import torch
|
||||
import os
|
||||
import nltk
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
||||
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
|
||||
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||
DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
|
||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
LLM_MODEL_CONFIG = {
|
||||
@@ -18,6 +22,7 @@ LLM_MODEL_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 5
|
||||
LLM_MODEL = "vicuna-13b"
|
||||
LIMIT_MODEL_CONCURRENCY = 5
|
||||
MAX_POSITION_EMBEDDINGS = 2048
|
||||
|
@@ -146,6 +146,16 @@ conv_vicuna_v1 = Conversation(
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
|
||||
conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
|
||||
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
|
||||
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
|
||||
default_conversation = conv_one_shot
|
||||
|
||||
conv_templates = {
|
||||
|
0
pilot/datasets/mysql/url.md
Normal file
0
pilot/datasets/mysql/url.md
Normal file
BIN
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0--快速入门系列课程.pdf
Normal file
BIN
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0--快速入门系列课程.pdf
Normal file
Binary file not shown.
BIN
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0--文档概览.pdf
Normal file
BIN
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0--文档概览.pdf
Normal file
Binary file not shown.
14963
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-OceanBase-介绍.pdf
Normal file
14963
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-OceanBase-介绍.pdf
Normal file
File diff suppressed because one or more lines are too long
16552
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-快速上手.pdf
Normal file
16552
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-快速上手.pdf
Normal file
File diff suppressed because one or more lines are too long
16103
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-版本发布记录.pdf
Normal file
16103
pilot/datasets/oceanbase/OceanBase-数据库-V4.0.0-版本发布记录.pdf
Normal file
File diff suppressed because one or more lines are too long
@@ -1,185 +0,0 @@
|
||||
执行计划是对一条 SQL 查询语句在数据库中执行过程的描述。用户可以通过 EXPLAIN 命令查看优化器针对指定 SQL 生成的逻辑执行计划。
|
||||
|
||||
如果要分析某条 SQL 的性能问题,通常需要先查看 SQL 的执行计划,排查每一步 SQL 执行是否存在问题。所以读懂执行计划是 SQL 优化的先决条件,而了解执行计划的算子是理解 EXPLAIN 命令的关键。
|
||||
|
||||
OceanBase 数据库的执行计划命令有三种模式:EXPLAIN BASIC、EXPLAIN 和 EXPLAIN EXTENDED。这三种模式对执行计划展现不同粒度的细节信息:
|
||||
|
||||
EXPLAIN BASIC 命令用于最基本的计划展示。
|
||||
|
||||
EXPLAIN EXTENDED 命令用于最详细的计划展示(通常在排查问题时使用这种展示模式)。
|
||||
|
||||
EXPLAIN 命令所展示的信息可以帮助普通用户了解整个计划的执行方式。
|
||||
|
||||
EXPLAIN 命令格式如下:
|
||||
EXPLAIN [BASIC | EXTENDED | PARTITIONS | FORMAT = format_name] [PRETTY | PRETTY_COLOR] explainable_stmt
|
||||
format_name:
|
||||
{ TRADITIONAL | JSON }
|
||||
explainable_stmt:
|
||||
{ SELECT st
|
||||
| DELETE statement
|
||||
| INSERT statement
|
||||
| REPLACE statement
|
||||
| UPDATE statement }
|
||||
|
||||
|
||||
EXPLAIN 命令适用于 SELECT、DELETE、INSERT、REPLACE 和 UPDATE 语句,显示优化器所提供的有关语句执行计划的信息,包括如何处理该语句,如何联接表以及以何种顺序联接表等信息。
|
||||
|
||||
一般来说,可以使用 EXPLAIN EXTENDED 命令,将表扫描的范围段展示出来。使用 EXPLAIN OUTLINE 命令可以显示 Outline 信息。
|
||||
|
||||
FORMAT 选项可用于选择输出格式。TRADITIONAL 表示以表格格式显示输出,这也是默认设置。JSON 表示以 JSON 格式显示信息。
|
||||
|
||||
使用 EXPLAIN PARTITITIONS 也可用于检查涉及分区表的查询。如果检查针对非分区表的查询,则不会产生错误,但 PARTIONS 列的值始终为 NULL。
|
||||
|
||||
对于复杂的执行计划,可以使用 PRETTY 或者 PRETTY_COLOR 选项将计划树中的父节点和子节点使用树线或彩色树线连接起来,使得执行计划展示更方便阅读。示例如下:
|
||||
obclient> CREATE TABLE p1table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 2;
|
||||
Query OK, 0 rows affected
|
||||
|
||||
obclient> CREATE TABLE p2table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 4;
|
||||
Query OK, 0 rows affected
|
||||
|
||||
obclient> EXPLAIN EXTENDED PRETTY_COLOR SELECT * FROM p1table p1 JOIN p2table p2 ON p1.c1=p2.c2\G
|
||||
*************************** 1. row ***************************
|
||||
Query Plan: ==========================================================
|
||||
|ID|OPERATOR |NAME |EST. ROWS|COST|
|
||||
----------------------------------------------------------
|
||||
|0 |PX COORDINATOR | |1 |278 |
|
||||
|1 | EXCHANGE OUT DISTR |:EX10001|1 |277 |
|
||||
|2 | HASH JOIN | |1 |276 |
|
||||
|3 | ├PX PARTITION ITERATOR | |1 |92 |
|
||||
|4 | │ TABLE SCAN |P1 |1 |92 |
|
||||
|5 | └EXCHANGE IN DISTR | |1 |184 |
|
||||
|6 | EXCHANGE OUT DISTR (PKEY)|:EX10000|1 |184 |
|
||||
|7 | PX PARTITION ITERATOR | |1 |183 |
|
||||
|8 | TABLE SCAN |P2 |1 |183 |
|
||||
==========================================================
|
||||
|
||||
Outputs & filters:
|
||||
-------------------------------------
|
||||
0 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil)
|
||||
1 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil), dop=1
|
||||
2 - output([P1.C1], [P2.C2], [P1.C2], [P2.C1]), filter(nil),
|
||||
equal_conds([P1.C1 = P2.C2]), other_conds(nil)
|
||||
3 - output([P1.C1], [P1.C2]), filter(nil)
|
||||
4 - output([P1.C1], [P1.C2]), filter(nil),
|
||||
access([P1.C1], [P1.C2]), partitions(p[0-1])
|
||||
5 - output([P2.C2], [P2.C1]), filter(nil)
|
||||
6 - (#keys=1, [P2.C2]), output([P2.C2], [P2.C1]), filter(nil), dop=1
|
||||
7 - output([P2.C1], [P2.C2]), filter(nil)
|
||||
8 - output([P2.C1], [P2.C2]), filter(nil),
|
||||
access([P2.C1], [P2.C2]), partitions(p[0-3])
|
||||
|
||||
1 row in set
|
||||
|
||||
|
||||
|
||||
|
||||
## 执行计划形状与算子信息
|
||||
|
||||
在数据库系统中,执行计划在内部通常是以树的形式来表示的,但是不同的数据库会选择不同的方式展示给用户。
|
||||
|
||||
如下示例分别为 PostgreSQL 数据库、Oracle 数据库和 OceanBase 数据库对于 TPCDS Q3 的计划展示。
|
||||
|
||||
```sql
|
||||
obclient> SELECT /*TPC-DS Q3*/ *
|
||||
FROM (SELECT dt.d_year,
|
||||
item.i_brand_id brand_id,
|
||||
item.i_brand brand,
|
||||
Sum(ss_net_profit) sum_agg
|
||||
FROM date_dim dt,
|
||||
store_sales,
|
||||
item
|
||||
WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
|
||||
AND store_sales.ss_item_sk = item.i_item_sk
|
||||
AND item.i_manufact_id = 914
|
||||
AND dt.d_moy = 11
|
||||
GROUP BY dt.d_year,
|
||||
item.i_brand,
|
||||
item.i_brand_id
|
||||
ORDER BY dt.d_year,
|
||||
sum_agg DESC,
|
||||
brand_id)
|
||||
WHERE ROWNUM <= 100;
|
||||
|
||||
PostgreSQL 数据库执行计划展示如下:
|
||||
Limit (cost=13986.86..13987.20 rows=27 width=91)
|
||||
Sort (cost=13986.86..13986.93 rows=27 width=65)
|
||||
Sort Key: dt.d_year, (sum(store_sales.ss_net_profit)), item.i_brand_id
|
||||
HashAggregate (cost=13985.95..13986.22 rows=27 width=65)
|
||||
Merge Join (cost=13884.21..13983.91 rows=204 width=65)
|
||||
Merge Cond: (dt.d_date_sk = store_sales.ss_sold_date_sk)
|
||||
Index Scan using date_dim_pkey on date_dim dt (cost=0.00..3494.62 rows=6080 width=8)
|
||||
Filter: (d_moy = 11)
|
||||
Sort (cost=12170.87..12177.27 rows=2560 width=65)
|
||||
Sort Key: store_sales.ss_sold_date_sk
|
||||
Nested Loop (cost=6.02..12025.94 rows=2560 width=65)
|
||||
Seq Scan on item (cost=0.00..1455.00 rows=16 width=59)
|
||||
Filter: (i_manufact_id = 914)
|
||||
Bitmap Heap Scan on store_sales (cost=6.02..658.94 rows=174 width=14)
|
||||
Recheck Cond: (ss_item_sk = item.i_item_sk)
|
||||
Bitmap Index Scan on store_sales_pkey (cost=0.00..5.97 rows=174 width=0)
|
||||
Index Cond: (ss_item_sk = item.i_item_sk)
|
||||
|
||||
|
||||
|
||||
Oracle 数据库执行计划展示如下:
|
||||
Plan hash value: 2331821367
|
||||
--------------------------------------------------------------------------------------------------
|
||||
| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time |
|
||||
--------------------------------------------------------------------------------------------------
|
||||
| 0 | SELECT STATEMENT | | 100 | 9100 | 3688 (1)| 00:00:01 |
|
||||
|* 1 | COUNT STOPKEY | | | | | |
|
||||
| 2 | VIEW | | 2736 | 243K| 3688 (1)| 00:00:01 |
|
||||
|* 3 | SORT ORDER BY STOPKEY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|
||||
| 4 | HASH GROUP BY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|
||||
|* 5 | HASH JOIN | | 2736 | 256K| 3686 (1)| 00:00:01 |
|
||||
|* 6 | TABLE ACCESS FULL | DATE_DIM | 6087 | 79131 | 376 (1)| 00:00:01 |
|
||||
| 7 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|
||||
| 8 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|
||||
|* 9 | TABLE ACCESS FULL | ITEM | 18 | 1188 | 375 (0)| 00:00:01 |
|
||||
|* 10 | INDEX RANGE SCAN | SYS_C0010069 | 159 | | 2 (0)| 00:00:01 |
|
||||
| 11 | TABLE ACCESS BY INDEX ROWID| STORE_SALES | 159 | 2703 | 163 (0)| 00:00:01 |
|
||||
--------------------------------------------------------------------------------------------------
|
||||
|
||||
OceanBase 数据库执行计划展示如下:
|
||||
|ID|OPERATOR |NAME |EST. ROWS|COST |
|
||||
-------------------------------------------------------
|
||||
|0 |LIMIT | |100 |81141|
|
||||
|1 | TOP-N SORT | |100 |81127|
|
||||
|2 | HASH GROUP BY | |2924 |68551|
|
||||
|3 | HASH JOIN | |2924 |65004|
|
||||
|4 | SUBPLAN SCAN |VIEW1 |2953 |19070|
|
||||
|5 | HASH GROUP BY | |2953 |18662|
|
||||
|6 | NESTED-LOOP JOIN| |2953 |15080|
|
||||
|7 | TABLE SCAN |ITEM |19 |11841|
|
||||
|8 | TABLE SCAN |STORE_SALES|161 |73 |
|
||||
|9 | TABLE SCAN |DT |6088 |29401|
|
||||
=======================================================
|
||||
|
||||
由示例可见,OceanBase 数据库的计划展示与 Oracle 数据库类似。
|
||||
|
||||
OceanBase 数据库执行计划中的各列的含义如下:
|
||||
列名 含义
|
||||
ID 执行树按照前序遍历的方式得到的编号(从 0 开始)。
|
||||
OPERATOR 操作算子的名称。
|
||||
NAME 对应表操作的表名(索引名)。
|
||||
EST. ROWS 估算该操作算子的输出行数。
|
||||
COST 该操作算子的执行代价(微秒)。
|
||||
|
||||
|
||||
OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形结构展示。其中每一个操作在树中的层次通过其在 operator 中的缩进予以展示,层次最深的优先执行,层次相同的以特定算子的执行顺序为标准来执行。
|
||||
|
||||
问题: update a not exists (b…)
|
||||
我一开始以为 B是驱动表,B的数据挺多的 后来看到NLAJ,是说左边的表关联右边的表
|
||||
所以这个的驱动表是不是实际是A,用A的匹配B的,这个理解有问题吗
|
||||
|
||||
回答: 没错 A 驱动 B的
|
||||
|
||||
问题: 光知道最下最右的是驱动表了 所以一开始搞得有点懵 :sweat_smile:
|
||||
|
||||
回答: nlj应该原理应该都是左表(驱动表)的记录探测右表(被驱动表), 选哪张成为左表或右表就基于一些其他考量了,比如数据量, 而anti join/semi join只是对 not exist/exist的一种优化,相关的原理和资料网上可以查阅一下
|
||||
|
||||
问题: 也就是nlj 就是按照之前理解的谁先执行 谁就是驱动表 也就是执行计划中的最右的表
|
||||
而anti join/semi join,谁在not exist左面,谁就是驱动表。这么理解对吧
|
||||
|
||||
回答: nlj也是左表的表是驱动表,这个要了解下计划执行方面的基本原理,取左表的一行数据,再遍历右表,一旦满足连接条件,就可以返回数据
|
||||
anti/semi只是因为not exists/exist的语义只是返回左表数据,改成anti join是一种计划优化,连接的方式比子查询更优
|
@@ -3,6 +3,71 @@
|
||||
|
||||
import torch
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_stream(model, tokenizer, params, device,
|
||||
context_len=2048, stream_interval=2):
|
||||
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
|
||||
prompt = params["prompt"]
|
||||
l_prompt = len(prompt)
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 256))
|
||||
stop_str = params.get("stop", None)
|
||||
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
output_ids = list(input_ids)
|
||||
|
||||
max_src_len = context_len - max_new_tokens - 8
|
||||
input_ids = input_ids[-max_src_len:]
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if i == 0:
|
||||
out = model(
|
||||
torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
else:
|
||||
attention_mask = torch.ones(
|
||||
1, past_key_values[0][0].shape[-2] + 1, device=device)
|
||||
out = model(input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
last_token_logits = logits[0][-1]
|
||||
|
||||
if device == "mps":
|
||||
# Switch to CPU by avoiding some bugs in mps backend.
|
||||
last_token_logits = last_token_logits.float().to("cpu")
|
||||
|
||||
if temperature < 1e-4:
|
||||
token = int(torch.argmax(last_token_logits))
|
||||
else:
|
||||
probs = torch.softmax(last_token_logits / temperature, dim=-1)
|
||||
token = int(torch.multinomial(probs, num_samples=1))
|
||||
|
||||
output_ids.append(token)
|
||||
|
||||
if token == tokenizer.eos_token_id:
|
||||
stopped = True
|
||||
else:
|
||||
stopped = False
|
||||
|
||||
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
||||
output = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
pos = output.rfind(stop_str, l_prompt)
|
||||
if pos != -1:
|
||||
output = output[:pos]
|
||||
stopped = True
|
||||
yield output
|
||||
|
||||
if stopped:
|
||||
break
|
||||
|
||||
del past_key_values
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||
prompt = params["prompt"]
|
||||
|
@@ -5,6 +5,7 @@ import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoModel
|
||||
)
|
||||
|
||||
from fastchat.serve.compression import compress_module
|
||||
@@ -23,20 +24,39 @@ class ModerLoader:
|
||||
"device_map": "auto",
|
||||
}
|
||||
|
||||
def loader(self, load_8bit=False, debug=False):
|
||||
def loader(self, num_gpus, load_8bit=False, debug=False):
|
||||
if self.device == "cpu":
|
||||
kwargs = {}
|
||||
elif self.device == "cuda":
|
||||
kwargs = {"torch_dtype": torch.float16}
|
||||
if num_gpus == "auto":
|
||||
kwargs["device_map"] = "auto"
|
||||
else:
|
||||
num_gpus = int(num_gpus)
|
||||
if num_gpus != 1:
|
||||
kwargs.update({
|
||||
"device_map": "auto",
|
||||
"max_memory": {i: "13GiB" for i in range(num_gpus)},
|
||||
})
|
||||
else:
|
||||
raise ValueError(f"Invalid device: {self.device}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
if "chatglm" in self.model_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_path,
|
||||
low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
if load_8bit:
|
||||
compress_module(model, self.device)
|
||||
|
||||
# if self.device == "cuda":
|
||||
# model.to(self.device)
|
||||
if (self.device == "cuda" and num_gpus == 1):
|
||||
model.to(self.device)
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
@@ -1,56 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
import gradio as gr
|
||||
from pilot.configs.model_config import *
|
||||
vicuna_base_uri = "http://192.168.31.114:21002/"
|
||||
vicuna_stream_path = "worker_generate_stream"
|
||||
vicuna_status_path = "worker_get_status"
|
||||
|
||||
def generate(prompt):
|
||||
params = {
|
||||
"model": "vicuna-13b",
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 512,
|
||||
"stop": "###"
|
||||
}
|
||||
|
||||
sts_response = requests.post(
|
||||
url=urljoin(vicuna_base_uri, vicuna_status_path)
|
||||
)
|
||||
print(sts_response.text)
|
||||
|
||||
response = requests.post(
|
||||
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
|
||||
)
|
||||
|
||||
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
if data["error_code"] == 0:
|
||||
output = data["text"]
|
||||
yield(output)
|
||||
|
||||
time.sleep(0.02)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(LLM_MODEL)
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("数据库SQL生成助手")
|
||||
with gr.Tab("SQL生成"):
|
||||
text_input = gr.TextArea()
|
||||
text_output = gr.TextArea()
|
||||
text_button = gr.Button("提交")
|
||||
|
||||
|
||||
text_button.click(generate, inputs=text_input, outputs=text_output)
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
||||
|
@@ -1,3 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
import gradio as gr
|
||||
from pilot.configs.model_config import *
|
||||
vicuna_base_uri = "http://192.168.31.114:21002/"
|
||||
vicuna_stream_path = "worker_generate_stream"
|
||||
vicuna_status_path = "worker_get_status"
|
||||
|
||||
def generate(prompt):
|
||||
params = {
|
||||
"model": "vicuna-13b",
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 512,
|
||||
"stop": "###"
|
||||
}
|
||||
|
||||
sts_response = requests.post(
|
||||
url=urljoin(vicuna_base_uri, vicuna_status_path)
|
||||
)
|
||||
print(sts_response.text)
|
||||
|
||||
response = requests.post(
|
||||
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
|
||||
)
|
||||
|
||||
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
if data["error_code"] == 0:
|
||||
output = data["text"]
|
||||
yield(output)
|
||||
|
||||
time.sleep(0.02)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(LLM_MODEL)
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("数据库SQL生成助手")
|
||||
with gr.Tab("SQL生成"):
|
||||
text_input = gr.TextArea()
|
||||
text_output = gr.TextArea()
|
||||
text_button = gr.Button("提交")
|
||||
|
||||
|
||||
text_button.click(generate, inputs=text_input, outputs=text_output)
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
||||
|
76
pilot/server/gradio_css.py
Normal file
76
pilot/server/gradio_css.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
code_highlight_css = (
|
||||
"""
|
||||
#chatbot .hll { background-color: #ffffcc }
|
||||
#chatbot .c { color: #408080; font-style: italic }
|
||||
#chatbot .err { border: 1px solid #FF0000 }
|
||||
#chatbot .k { color: #008000; font-weight: bold }
|
||||
#chatbot .o { color: #666666 }
|
||||
#chatbot .ch { color: #408080; font-style: italic }
|
||||
#chatbot .cm { color: #408080; font-style: italic }
|
||||
#chatbot .cp { color: #BC7A00 }
|
||||
#chatbot .cpf { color: #408080; font-style: italic }
|
||||
#chatbot .c1 { color: #408080; font-style: italic }
|
||||
#chatbot .cs { color: #408080; font-style: italic }
|
||||
#chatbot .gd { color: #A00000 }
|
||||
#chatbot .ge { font-style: italic }
|
||||
#chatbot .gr { color: #FF0000 }
|
||||
#chatbot .gh { color: #000080; font-weight: bold }
|
||||
#chatbot .gi { color: #00A000 }
|
||||
#chatbot .go { color: #888888 }
|
||||
#chatbot .gp { color: #000080; font-weight: bold }
|
||||
#chatbot .gs { font-weight: bold }
|
||||
#chatbot .gu { color: #800080; font-weight: bold }
|
||||
#chatbot .gt { color: #0044DD }
|
||||
#chatbot .kc { color: #008000; font-weight: bold }
|
||||
#chatbot .kd { color: #008000; font-weight: bold }
|
||||
#chatbot .kn { color: #008000; font-weight: bold }
|
||||
#chatbot .kp { color: #008000 }
|
||||
#chatbot .kr { color: #008000; font-weight: bold }
|
||||
#chatbot .kt { color: #B00040 }
|
||||
#chatbot .m { color: #666666 }
|
||||
#chatbot .s { color: #BA2121 }
|
||||
#chatbot .na { color: #7D9029 }
|
||||
#chatbot .nb { color: #008000 }
|
||||
#chatbot .nc { color: #0000FF; font-weight: bold }
|
||||
#chatbot .no { color: #880000 }
|
||||
#chatbot .nd { color: #AA22FF }
|
||||
#chatbot .ni { color: #999999; font-weight: bold }
|
||||
#chatbot .ne { color: #D2413A; font-weight: bold }
|
||||
#chatbot .nf { color: #0000FF }
|
||||
#chatbot .nl { color: #A0A000 }
|
||||
#chatbot .nn { color: #0000FF; font-weight: bold }
|
||||
#chatbot .nt { color: #008000; font-weight: bold }
|
||||
#chatbot .nv { color: #19177C }
|
||||
#chatbot .ow { color: #AA22FF; font-weight: bold }
|
||||
#chatbot .w { color: #bbbbbb }
|
||||
#chatbot .mb { color: #666666 }
|
||||
#chatbot .mf { color: #666666 }
|
||||
#chatbot .mh { color: #666666 }
|
||||
#chatbot .mi { color: #666666 }
|
||||
#chatbot .mo { color: #666666 }
|
||||
#chatbot .sa { color: #BA2121 }
|
||||
#chatbot .sb { color: #BA2121 }
|
||||
#chatbot .sc { color: #BA2121 }
|
||||
#chatbot .dl { color: #BA2121 }
|
||||
#chatbot .sd { color: #BA2121; font-style: italic }
|
||||
#chatbot .s2 { color: #BA2121 }
|
||||
#chatbot .se { color: #BB6622; font-weight: bold }
|
||||
#chatbot .sh { color: #BA2121 }
|
||||
#chatbot .si { color: #BB6688; font-weight: bold }
|
||||
#chatbot .sx { color: #008000 }
|
||||
#chatbot .sr { color: #BB6688 }
|
||||
#chatbot .s1 { color: #BA2121 }
|
||||
#chatbot .ss { color: #19177C }
|
||||
#chatbot .bp { color: #008000 }
|
||||
#chatbot .fm { color: #0000FF }
|
||||
#chatbot .vc { color: #19177C }
|
||||
#chatbot .vg { color: #19177C }
|
||||
#chatbot .vi { color: #19177C }
|
||||
#chatbot .vm { color: #19177C }
|
||||
#chatbot .il { color: #666666 }
|
||||
""")
|
||||
#.highlight { background: #f8f8f8; }
|
||||
|
167
pilot/server/gradio_patch.py
Normal file
167
pilot/server/gradio_patch.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Fork from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/gradio_patch.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from gradio.components import *
|
||||
from markdown2 import Markdown
|
||||
|
||||
|
||||
class _Keywords(Enum):
|
||||
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
|
||||
FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
|
||||
|
||||
|
||||
@document("style")
|
||||
class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
"""
|
||||
Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
|
||||
|
||||
Demos: chatbot_simple, chatbot_multimodal
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: List[Tuple[str | None, str | None]] | Callable | None = None,
|
||||
color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
|
||||
*,
|
||||
label: str | None = None,
|
||||
every: float | None = None,
|
||||
show_label: bool = True,
|
||||
visible: bool = True,
|
||||
elem_id: str | None = None,
|
||||
elem_classes: List[str] | str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
|
||||
label: component name in interface.
|
||||
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
|
||||
show_label: if True, will display label.
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
if color_map is not None:
|
||||
warnings.warn(
|
||||
"The 'color_map' parameter has been deprecated.",
|
||||
)
|
||||
#self.md = utils.get_markdown_parser()
|
||||
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
|
||||
self.select: EventListenerMethod
|
||||
"""
|
||||
Event listener for when the user selects message from Chatbot.
|
||||
Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
|
||||
See EventData documentation on how to use this event data.
|
||||
"""
|
||||
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
every=every,
|
||||
show_label=show_label,
|
||||
visible=visible,
|
||||
elem_id=elem_id,
|
||||
elem_classes=elem_classes,
|
||||
value=value,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"value": self.value,
|
||||
"selectable": self.selectable,
|
||||
**IOComponent.get_config(self),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
|
||||
label: str | None = None,
|
||||
show_label: bool | None = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
updated_config = {
|
||||
"label": label,
|
||||
"show_label": show_label,
|
||||
"visible": visible,
|
||||
"value": value,
|
||||
"__type__": "update",
|
||||
}
|
||||
return updated_config
|
||||
|
||||
def _process_chat_messages(
|
||||
self, chat_message: str | Tuple | List | Dict | None
|
||||
) -> str | Dict | None:
|
||||
if chat_message is None:
|
||||
return None
|
||||
elif isinstance(chat_message, (tuple, list)):
|
||||
mime_type = processing_utils.get_mimetype(chat_message[0])
|
||||
return {
|
||||
"name": chat_message[0],
|
||||
"mime_type": mime_type,
|
||||
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
||||
"data": None, # These last two fields are filled in by the frontend
|
||||
"is_file": True,
|
||||
}
|
||||
elif isinstance(
|
||||
chat_message, dict
|
||||
): # This happens for previously processed messages
|
||||
return chat_message
|
||||
elif isinstance(chat_message, str):
|
||||
#return self.md.render(chat_message)
|
||||
return str(self.md.convert(chat_message))
|
||||
else:
|
||||
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: List[
|
||||
Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
|
||||
],
|
||||
) -> List[Tuple[str | Dict | None, str | Dict | None]]:
|
||||
"""
|
||||
Parameters:
|
||||
y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
|
||||
Returns:
|
||||
List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
|
||||
"""
|
||||
if y is None:
|
||||
return []
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
assert isinstance(
|
||||
message_pair, (tuple, list)
|
||||
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
assert (
|
||||
len(message_pair) == 2
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
processed_messages.append(
|
||||
(
|
||||
#self._process_chat_messages(message_pair[0]),
|
||||
'<pre style="font-family: var(--font)">' +
|
||||
message_pair[0] + "</pre>",
|
||||
self._process_chat_messages(message_pair[1]),
|
||||
)
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
def style(self, height: int | None = None, **kwargs):
|
||||
"""
|
||||
This method can be used to change the appearance of the Chatbot component.
|
||||
"""
|
||||
if height is not None:
|
||||
self._style["height"] = height
|
||||
if kwargs.get("color_map") is not None:
|
||||
warnings.warn("The 'color_map' parameter has been deprecated.")
|
||||
|
||||
Component.style(
|
||||
self,
|
||||
**kwargs,
|
||||
)
|
||||
return self
|
||||
|
||||
|
@@ -1,50 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import json
|
||||
import torch
|
||||
import gradio as gr
|
||||
from fastchat.serve.inference import generate_stream
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODE,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
def generate(prompt):
|
||||
model.to(device)
|
||||
print(model, tokenizer)
|
||||
params = {
|
||||
"model": "vicuna-13b",
|
||||
"prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 512,
|
||||
"stop": "###"
|
||||
}
|
||||
output = generate_stream(
|
||||
model, tokenizer, params, device, context_len=2048, stream_interval=2)
|
||||
|
||||
for chunk in output:
|
||||
yield chunk
|
||||
|
||||
if __name__ == "__main__":
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("数据库SQL生成助手")
|
||||
with gr.Tab("SQL生成"):
|
||||
text_input = gr.TextArea()
|
||||
text_output = gr.TextArea()
|
||||
text_button = gr.Button("提交")
|
||||
|
||||
|
||||
text_button.click(generate, inputs=text_input, outputs=text_output)
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
33
pilot/server/vectordb_qa.py
Normal file
33
pilot/server/vectordb_qa.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pilot.conversation import conv_qk_prompt_template
|
||||
from langchain.chains import RetrievalQA
|
||||
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
|
||||
class KnownLedgeBaseQA:
|
||||
|
||||
llm: object = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
k2v = KnownLedge2Vector()
|
||||
self.vector_store = k2v.init_vector_store()
|
||||
|
||||
def get_answer(self, query):
|
||||
prompt_template = conv_qk_prompt_template
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template=prompt_template,
|
||||
input_variables=["context", "question"]
|
||||
)
|
||||
|
||||
knownledge_chain = RetrievalQA.from_llm(
|
||||
llm=self.llm,
|
||||
retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
|
||||
prompt=prompt
|
||||
)
|
||||
knownledge_chain.return_source_documents = True
|
||||
result = knownledge_chain({"query": query})
|
||||
yield result
|
@@ -7,10 +7,12 @@ import json
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastchat.serve.inference import generate_stream
|
||||
from pilot.model.inference import generate_stream
|
||||
from pydantic import BaseModel
|
||||
from pilot.model.inference import generate_output, get_embeddings
|
||||
from fastchat.serve.inference import load_model
|
||||
|
||||
|
||||
from pilot.model.loader import ModerLoader
|
||||
from pilot.configs.model_config import *
|
||||
|
||||
@@ -20,9 +22,9 @@ model_path = LLM_MODEL_CONFIG[LLM_MODEL]
|
||||
global_counter = 0
|
||||
model_semaphore = None
|
||||
|
||||
# ml = ModerLoader(model_path=model_path)
|
||||
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
|
||||
model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
|
||||
ml = ModerLoader(model_path=model_path)
|
||||
model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
|
||||
#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self):
|
||||
|
@@ -29,8 +29,8 @@ from fastchat.utils import (
|
||||
moderation_msg
|
||||
)
|
||||
|
||||
from fastchat.serve.gradio_patch import Chatbot as grChatbot
|
||||
from fastchat.serve.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
@@ -281,7 +281,7 @@ def build_single_model_ui():
|
||||
"""
|
||||
learn_more_markdown = """
|
||||
### Licence
|
||||
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
|
||||
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
|
||||
"""
|
||||
|
||||
state = gr.State()
|
||||
|
91
pilot/vector_store/file_loader.py
Normal file
91
pilot/vector_store/file_loader.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional, List, Dict
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
|
||||
from langchain.chains import VectorDBQA
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
class KnownLedge2Vector:
|
||||
|
||||
embeddings: object = None
|
||||
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
|
||||
top_k: int = VECTOR_SEARCH_TOP_K
|
||||
|
||||
def __init__(self, model_name=None) -> None:
|
||||
if not model_name:
|
||||
# use default embedding model
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
|
||||
def init_vector_store(self):
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
print("向量数据库持久化地址: ", persist_dir)
|
||||
if os.path.exists(persist_dir):
|
||||
# 从本地持久化文件中Load
|
||||
print("从本地向量加载数据...")
|
||||
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||
# vector_store.add_documents(documents=documents)
|
||||
else:
|
||||
documents = self.load_knownlege()
|
||||
# 重新初始化
|
||||
vector_store = Chroma.from_documents(documents=documents,
|
||||
embedding=self.embeddings,
|
||||
persist_directory=persist_dir)
|
||||
vector_store.persist()
|
||||
return vector_store
|
||||
|
||||
def load_knownlege(self):
|
||||
docments = []
|
||||
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
docs = self._load_file(filename)
|
||||
# 更新metadata数据
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
|
||||
print("文档2向量初始化中, 请稍等...", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
return docments
|
||||
|
||||
def _load_file(self, filename):
|
||||
# 加载文件
|
||||
if filename.lower().endswith(".pdf"):
|
||||
loader = UnstructuredFileLoader(filename)
|
||||
text_splitor = CharacterTextSplitter()
|
||||
docs = loader.load_and_split(text_splitor)
|
||||
else:
|
||||
loader = UnstructuredFileLoader(filename, mode="elements")
|
||||
text_splitor = CharacterTextSplitter()
|
||||
docs = loader.load_and_split(text_splitor)
|
||||
return docs
|
||||
|
||||
def _load_from_url(self, url):
|
||||
"""Load data from url address"""
|
||||
pass
|
||||
|
||||
|
||||
def query(self, q):
|
||||
"""Query similar doc from Vector """
|
||||
vector_store = self.init_vector_store()
|
||||
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
|
||||
for doc in docs:
|
||||
dc, s = doc
|
||||
yield s, dc
|
||||
|
||||
if __name__ == "__main__":
|
||||
k2v = KnownLedge2Vector()
|
||||
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
print(persist_dir)
|
||||
for s, dc in k2v.query("什么是OceanBase"):
|
||||
print(s, dc.page_content, dc.metadata)
|
||||
|
Reference in New Issue
Block a user