Merge pull request #1 from csunny/dev

add sqlgpt and qa
This commit is contained in:
magic.chen
2023-04-29 23:29:43 +08:00
committed by GitHub
23 changed files with 893 additions and 3 deletions

1
.gitignore vendored

@@ -20,6 +20,7 @@ parts/
sdist/
var/
wheels/
models/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/

25
.vscode/launch.json vendored Normal file

@@ -0,0 +1,25 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {"PYTHONPATH": "${workspaceFolder}"},
            "envFile": "${workspaceFolder}/.env"
        },
        {
            "name": "Python: Module",
            "type": "python",
            "request": "launch",
            "module": "pilot",
            "justMyCode": true
        }
    ]
}


@@ -1,10 +1,25 @@
# DB-GPT
An Open Database-GPT Experiment
![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
DB-GPT is an experimental open-source application that builds on FastChat and uses Vicuna as its base model. It also incorporates LangChain and llama-index embedding-based knowledge to improve its Database-QA capabilities.
DB-GPT is an experimental open-source application based on the Vicuna base model.
Overall, it appears to be a sophisticated and innovative tool for working with databases. If you have any specific questions about how to use or implement DB-GPT in your work, please let me know and I'll do my best to assist you.
# Install
1. Run the model server
```
cd pilot/server
uvicorn vicuna_server:app --host 0.0.0.0
```
## Features
Coming soon, please wait...
2. Run the Gradio web UI
```
python app.py
```
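Once the model server is up, it can also be queried directly over HTTP. A minimal sketch is shown below; it assumes a local deployment on uvicorn's default port 8000, and the request fields mirror the `/generate` endpoint added in `pilot/server` in this PR:
```python
# Minimal sketch: query the model server's /generate endpoint (local host/port assumed).
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "Write a SQL query that counts the rows in table t1.",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "stop": ["###"],
    },
)
print(resp.json()["response"])
```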
# Features
- SQL generation
- Knowledge-based Database-QA
- SQL diagnosis

0
asserts/readme.md Normal file

1
docs/introduct.md Normal file

@@ -0,0 +1 @@
#

62
environment.yml Normal file

@@ -0,0 +1,62 @@
name: db_pgt
channels:
  - pytorch
  - defaults
  - anaconda
dependencies:
  - python=3.9
  - cudatoolkit
  - pip
  - pytorch=1.12.1
  - pytorch-mutex=1.0=cuda
  - torchaudio=0.12.1
  - torchvision=0.13.1
  - pip:
      - accelerate==0.16.0
      - aiohttp==3.8.4
      - aiosignal==1.3.1
      - async-timeout==4.0.2
      - attrs==22.2.0
      - bitsandbytes==0.37.0
      - cchardet==2.1.7
      - chardet==5.1.0
      - contourpy==1.0.7
      - cycler==0.11.0
      - filelock==3.9.0
      - fonttools==4.38.0
      - frozenlist==1.3.3
      - huggingface-hub==0.13.4
      - importlib-resources==5.12.0
      - kiwisolver==1.4.4
      - matplotlib==3.7.0
      - multidict==6.0.4
      - openai==0.27.0
      - packaging==23.0
      - psutil==5.9.4
      - pycocotools==2.0.6
      - pyparsing==3.0.9
      - python-dateutil==2.8.2
      - pyyaml==6.0
      - regex==2022.10.31
      - tokenizers==0.13.2
      - tqdm==4.64.1
      - transformers==4.28.0
      - timm==0.6.13
      - spacy==3.5.1
      - webdataset==0.2.48
      - scikit-learn==1.2.2
      - scipy==1.10.1
      - yarl==1.8.2
      - zipp==3.14.0
      - omegaconf==2.3.0
      - opencv-python==4.7.0.72
      - iopath==0.1.10
      - tenacity==8.2.2
      - peft
      - pycocoevalcap
      - sentence-transformers
      - umap-learn
      - notebook
      - gradio==3.24.1
      - gradio-client==0.0.8
      - wandb

19
examples/gpt_index.py Normal file

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import logging
import sys
from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# read the documents from the "data" directory
documents = SimpleDirectoryReader("data").load_data()
# split each document into chunks (max token size 500) and convert each chunk into a vector
index = GPTSimpleVectorIndex(documents)
# save index
index.save_to_disk("index.json")
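A follow-up sketch of how the persisted index might be queried later; it assumes the same pre-0.5-style llama_index API used above, where `GPTSimpleVectorIndex.load_from_disk` and `query` are available:
```python
# Sketch: reload the index saved above and ask a question over the indexed documents.
from llama_index import GPTSimpleVectorIndex

index = GPTSimpleVectorIndex.load_from_disk("index.json")
response = index.query("Summarize the documents in the data directory.")
print(response)
```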

240
examples/t5_example.py Normal file

@@ -0,0 +1,240 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from llama_index import SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTSimpleVectorIndex, PromptHelper
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LLMPredictor
import torch
from langchain.llms.base import LLM
from transformers import pipeline
class FlanLLM(LLM):
    model_name = "google/flan-t5-large"
    pipeline = pipeline("text2text-generation", model=model_name, device=0, model_kwargs={
        "torch_dtype": torch.bfloat16
    })

    def _call(self, prompt, stop=None):
        return self.pipeline(prompt, max_length=9999)[0]["generated_text"]

    def _identifying_params(self):
        return {"name_of_model": self.model_name}

    def _llm_type(self):
        return "custome"


llm_predictor = LLMPredictor(llm=FlanLLM())
hfemb = HuggingFaceEmbeddings()
embed_model = LangchainEmbedding(hfemb)
text1 = """
执行计划是对一条 SQL 查询语句在数据库中执行过程的描述。用户可以通过 EXPLAIN 命令查看优化器针对指定 SQL 生成的逻辑执行计划。
如果要分析某条 SQL 的性能问题,通常需要先查看 SQL 的执行计划,排查每一步 SQL 执行是否存在问题。所以读懂执行计划是 SQL 优化的先决条件,而了解执行计划的算子是理解 EXPLAIN 命令的关键。
OceanBase 数据库的执行计划命令有三种模式EXPLAIN BASIC、EXPLAIN 和 EXPLAIN EXTENDED。这三种模式对执行计划展现不同粒度的细节信息:
EXPLAIN BASIC 命令用于最基本的计划展示。
EXPLAIN EXTENDED 命令用于最详细的计划展示(通常在排查问题时使用这种展示模式)。
EXPLAIN 命令所展示的信息可以帮助普通用户了解整个计划的执行方式。
EXPLAIN 命令格式如下:
EXPLAIN [BASIC | EXTENDED | PARTITIONS | FORMAT = format_name] [PRETTY | PRETTY_COLOR] explainable_stmt
format_name:
{ TRADITIONAL | JSON }
explainable_stmt:
{ SELECT statement
| DELETE statement
| INSERT statement
| REPLACE statement
| UPDATE statement }
EXPLAIN 命令适用于 SELECT、DELETE、INSERT、REPLACE 和 UPDATE 语句,显示优化器所提供的有关语句执行计划的信息,包括如何处理该语句,如何联接表以及以何种顺序联接表等信息。
一般来说,可以使用 EXPLAIN EXTENDED 命令,将表扫描的范围段展示出来。使用 EXPLAIN OUTLINE 命令可以显示 Outline 信息。
FORMAT 选项可用于选择输出格式。TRADITIONAL 表示以表格格式显示输出这也是默认设置。JSON 表示以 JSON 格式显示信息。
使用 EXPLAIN PARTITITIONS 也可用于检查涉及分区表的查询。如果检查针对非分区表的查询,则不会产生错误,但 PARTIONS 列的值始终为 NULL。
对于复杂的执行计划,可以使用 PRETTY 或者 PRETTY_COLOR 选项将计划树中的父节点和子节点使用树线或彩色树线连接起来,使得执行计划展示更方便阅读。示例如下:
obclient> CREATE TABLE p1table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 2;
Query OK, 0 rows affected
obclient> CREATE TABLE p2table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 4;
Query OK, 0 rows affected
obclient> EXPLAIN EXTENDED PRETTY_COLOR SELECT * FROM p1table p1 JOIN p2table p2 ON p1.c1=p2.c2\G
*************************** 1. row ***************************
Query Plan: ==========================================================
|ID|OPERATOR |NAME |EST. ROWS|COST|
----------------------------------------------------------
|0 |PX COORDINATOR | |1 |278 |
|1 | EXCHANGE OUT DISTR |:EX10001|1 |277 |
|2 | HASH JOIN | |1 |276 |
|3 | ├PX PARTITION ITERATOR | |1 |92 |
|4 | │ TABLE SCAN |P1 |1 |92 |
|5 | └EXCHANGE IN DISTR | |1 |184 |
|6 | EXCHANGE OUT DISTR (PKEY)|:EX10000|1 |184 |
|7 | PX PARTITION ITERATOR | |1 |183 |
|8 | TABLE SCAN |P2 |1 |183 |
==========================================================
Outputs & filters:
-------------------------------------
0 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil)
1 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil), dop=1
2 - output([P1.C1], [P2.C2], [P1.C2], [P2.C1]), filter(nil),
equal_conds([P1.C1 = P2.C2]), other_conds(nil)
3 - output([P1.C1], [P1.C2]), filter(nil)
4 - output([P1.C1], [P1.C2]), filter(nil),
access([P1.C1], [P1.C2]), partitions(p[0-1])
5 - output([P2.C2], [P2.C1]), filter(nil)
6 - (#keys=1, [P2.C2]), output([P2.C2], [P2.C1]), filter(nil), dop=1
7 - output([P2.C1], [P2.C2]), filter(nil)
8 - output([P2.C1], [P2.C2]), filter(nil),
access([P2.C1], [P2.C2]), partitions(p[0-3])
1 row in set
## 执行计划形状与算子信息
在数据库系统中,执行计划在内部通常是以树的形式来表示的,但是不同的数据库会选择不同的方式展示给用户。
如下示例分别为 PostgreSQL 数据库、Oracle 数据库和 OceanBase 数据库对于 TPCDS Q3 的计划展示。
```sql
obclient> SELECT /*TPC-DS Q3*/ *
FROM (SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
Sum(ss_net_profit) sum_agg
FROM date_dim dt,
store_sales,
item
WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
AND store_sales.ss_item_sk = item.i_item_sk
AND item.i_manufact_id = 914
AND dt.d_moy = 11
GROUP BY dt.d_year,
item.i_brand,
item.i_brand_id
ORDER BY dt.d_year,
sum_agg DESC,
brand_id)
WHERE ROWNUM <= 100;
PostgreSQL 数据库执行计划展示如下:
Limit (cost=13986.86..13987.20 rows=27 width=91)
Sort (cost=13986.86..13986.93 rows=27 width=65)
Sort Key: dt.d_year, (sum(store_sales.ss_net_profit)), item.i_brand_id
HashAggregate (cost=13985.95..13986.22 rows=27 width=65)
Merge Join (cost=13884.21..13983.91 rows=204 width=65)
Merge Cond: (dt.d_date_sk = store_sales.ss_sold_date_sk)
Index Scan using date_dim_pkey on date_dim dt (cost=0.00..3494.62 rows=6080 width=8)
Filter: (d_moy = 11)
Sort (cost=12170.87..12177.27 rows=2560 width=65)
Sort Key: store_sales.ss_sold_date_sk
Nested Loop (cost=6.02..12025.94 rows=2560 width=65)
Seq Scan on item (cost=0.00..1455.00 rows=16 width=59)
Filter: (i_manufact_id = 914)
Bitmap Heap Scan on store_sales (cost=6.02..658.94 rows=174 width=14)
Recheck Cond: (ss_item_sk = item.i_item_sk)
Bitmap Index Scan on store_sales_pkey (cost=0.00..5.97 rows=174 width=0)
Index Cond: (ss_item_sk = item.i_item_sk)
Oracle 数据库执行计划展示如下:
Plan hash value: 2331821367
--------------------------------------------------------------------------------------------------
| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time |
--------------------------------------------------------------------------------------------------
| 0 | SELECT STATEMENT | | 100 | 9100 | 3688 (1)| 00:00:01 |
|* 1 | COUNT STOPKEY | | | | | |
| 2 | VIEW | | 2736 | 243K| 3688 (1)| 00:00:01 |
|* 3 | SORT ORDER BY STOPKEY | | 2736 | 256K| 3688 (1)| 00:00:01 |
| 4 | HASH GROUP BY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|* 5 | HASH JOIN | | 2736 | 256K| 3686 (1)| 00:00:01 |
|* 6 | TABLE ACCESS FULL | DATE_DIM | 6087 | 79131 | 376 (1)| 00:00:01 |
| 7 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
| 8 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|* 9 | TABLE ACCESS FULL | ITEM | 18 | 1188 | 375 (0)| 00:00:01 |
|* 10 | INDEX RANGE SCAN | SYS_C0010069 | 159 | | 2 (0)| 00:00:01 |
| 11 | TABLE ACCESS BY INDEX ROWID| STORE_SALES | 159 | 2703 | 163 (0)| 00:00:01 |
--------------------------------------------------------------------------------------------------
OceanBase 数据库执行计划展示如下:
|ID|OPERATOR |NAME |EST. ROWS|COST |
-------------------------------------------------------
|0 |LIMIT | |100 |81141|
|1 | TOP-N SORT | |100 |81127|
|2 | HASH GROUP BY | |2924 |68551|
|3 | HASH JOIN | |2924 |65004|
|4 | SUBPLAN SCAN |VIEW1 |2953 |19070|
|5 | HASH GROUP BY | |2953 |18662|
|6 | NESTED-LOOP JOIN| |2953 |15080|
|7 | TABLE SCAN |ITEM |19 |11841|
|8 | TABLE SCAN |STORE_SALES|161 |73 |
|9 | TABLE SCAN |DT |6088 |29401|
=======================================================
由示例可见OceanBase 数据库的计划展示与 Oracle 数据库类似。
OceanBase 数据库执行计划中的各列的含义如下:
列名 含义
ID 执行树按照前序遍历的方式得到的编号(从 0 开始)。
OPERATOR 操作算子的名称。
NAME 对应表操作的表名(索引名)。
EST. ROWS 估算该操作算子的输出行数。
COST 该操作算子的执行代价(微秒)。
OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形结构展示。其中每一个操作在树中的层次通过其在 operator 中的缩进予以展示,层次最深的优先执行,层次相同的以特定算子的执行顺序为标准来执行。
问题: update a not exists (b…)
我一开始以为 B是驱动表B的数据挺多的 后来看到NLAJ是说左边的表关联右边的表
所以这个的驱动表是不是实际是A用A的匹配B的这个理解有问题吗
回答: 没错 A 驱动 B的
问题: 光知道最下最右的是驱动表了 所以一开始搞得有点懵 :sweat_smile:
回答: nlj应该原理应该都是左表(驱动表)的记录探测右表(被驱动表) 选哪张成为左表或右表就基于一些其他考量了,比如数据量, 而anti join/semi join只是对 not exist/exist的一种优化相关的原理和资料网上可以查阅一下
问题: 也就是nlj 就是按照之前理解的谁先执行 谁就是驱动表 也就是执行计划中的最右的表
而anti join/semi join谁在not exist左面谁就是驱动表。这么理解对吧
回答: nlj也是左表的表是驱动表这个要了解下计划执行方面的基本原理取左表的一行数据再遍历右表一旦满足连接条件就可以返回数据
anti/semi只是因为not exists/exist的语义只是返回左表数据改成anti join是一种计划优化连接的方式比子查询更优
"""
from llama_index import Document
text_list = [text1]
documents = [Document(t) for t in text_list]
num_output = 250
max_input_size = 512
max_chunk_overlap = 20
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
index = GPTListIndex(documents, embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper)
index.save_to_disk("index.json")
if __name__ == "__main__":
    import logging
    logging.getLogger().setLevel(logging.CRITICAL)
    for d in documents:
        print(d)
    # Query (in Chinese): "How many execution-plan (EXPLAIN) command modes does the database have?"
    response = index.query("数据库的执行计划命令有多少?")
    print(response)

1
pilot/__init__.py Normal file

@@ -0,0 +1 @@
__version__ = "0.0.1"

55
pilot/app.py Normal file

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import gradio as gr
from langchain.agents import (
    load_tools,
    initialize_agent,
    AgentType
)

from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import Document, GPTSimpleVectorIndex


def agent_demo():
    llm = VicunaRequestLLM()
    tools = load_tools(['python_repl'], llm=llm)
    agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
    agent.run(
        "Write a SQL script that Query 'select count(1)!'"
    )


def knowledged_qa_demo(text_list):
    llm_predictor = LLMPredictor(llm=VicunaRequestLLM())
    hfemb = VicunaEmbeddingLLM()
    embed_model = LangchainEmbedding(hfemb)
    documents = [Document(t) for t in text_list]
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
    index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
    return index


def get_answer(q):
    base_knowledge = """ 这是一段测试文字 """  # placeholder knowledge: "This is a piece of test text"
    text_list = [base_knowledge]
    index = knowledged_qa_demo(text_list)
    response = index.query(q)
    return response.response


if __name__ == "__main__":
    # agent_demo()
    with gr.Blocks() as demo:
        gr.Markdown("数据库智能助手")  # "Database intelligence assistant"
        with gr.Tab("知识问答"):  # "Knowledge Q&A"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button()
            text_button.click(get_answer, inputs=text_input, outputs=text_output)
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
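As an illustrative variation (not part of this PR), the same QA flow could be fed real files instead of the hard-coded test string, reusing `knowledged_qa_demo` together with `SimpleDirectoryReader` from the gpt_index example; the `data_dir` default and the `.text` access on the old llama_index `Document` type are assumptions:
```python
# Hypothetical helper: build the QA index from files on disk instead of the test string.
from llama_index import SimpleDirectoryReader

def get_answer_from_dir(q, data_dir="data"):
    docs = SimpleDirectoryReader(data_dir).load_data()
    index = knowledged_qa_demo([d.text for d in docs])  # reuse the index builder above
    return index.query(q).response
```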


@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
import os
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
llm_model_config = {
    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
    "vicuna-13b": os.path.join(model_path, "vicuna-13b")
}
LLM_MODEL = "vicuna-13b"
vicuna_model_server = "http://192.168.31.114:8000/"
# Load model config
isload_8bit = True
isdebug = False


@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

0
pilot/model/__init__.py Normal file

87
pilot/model/inference.py Normal file

@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_parameter = params.get("stop", None)
    print(tokenizer.__dir__())
    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")

    pos = -1
    input_ids = tokenizer(prompt).input_ids
    output_ids = []

    # Truncate the prompt so that prompt + generated tokens fit in the context window.
    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    # Autoregressive decoding loop, reusing the KV cache between steps.
    for i in range(max_new_tokens):
        if i == 0:
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]
        if temperature < 1e-4:
            # Greedy decoding for (near-)zero temperature.
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        stopped = token == tokenizer.eos_token_id

        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop_strings:
            pos = output.rfind(stop_str)
            if pos != -1:
                output = output[:pos]
                stopped = True
                break
        if stopped:
            break

    del past_key_values
    if pos != -1:
        return output[:pos]
    return output


@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    input_ids = tokenizer(prompt).input_ids
    input_embeddings = model.get_input_embeddings()
    embeddings = input_embeddings(torch.LongTensor([input_ids]))
    mean = torch.mean(embeddings[0], 0).cpu().detach()
    return mean

42
pilot/model/loader.py Normal file

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

from fastchat.serve.compression import compress_module


class ModerLoader:

    kwargs = {}

    def __init__(self, model_path) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }

    def loader(self, load_8bit=False, debug=False):
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)

        if debug:
            print(model)

        if load_8bit:
            compress_module(model, self.device)

        # if self.device == "cuda":
        #     model.to(self.device)

        return model, tokenizer

84
pilot/model/vicuna_llm.py Normal file

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import json
import requests
from urllib.parse import urljoin
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel
from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM
from configs.model_config import *
class VicunaRequestLLM(LLM):

    vicuna_generate_path = "generate"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if isinstance(stop, list):
            stop = stop + ["Observation:"]
        params = {
            "prompt": prompt,
            "temperature": 0,
            "max_new_tokens": 256,
            "stop": stop
        }
        response = requests.post(
            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
            data=json.dumps(params)
        )
        response.raise_for_status()
        return response.json()["response"]

    @property
    def _llm_type(self) -> str:
        return "custome"

    def _identifying_params(self) -> Mapping[str, Any]:
        return {}


class VicunaEmbeddingLLM(BaseModel, Embeddings):

    vicuna_embedding_path = "embedding"

    def _call(self, prompt: str) -> str:
        p = prompt.strip()
        print("Sending prompt ", p)

        response = requests.post(
            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
            json={
                "prompt": p
            }
        )
        response.raise_for_status()
        return response.json()["response"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Vicuna's server embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        for text in texts:
            response = self.embed_query(text)
            results.append(response)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Call out to Vicuna's server embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = self._call(text)
        return embedding

0
pilot/server/__init__.py Normal file

56
pilot/server/chatbot.py Normal file

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"
def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }

    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
    )

    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                yield(output)

            time.sleep(0.02)


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")  # "Database SQL generation assistant"
        with gr.Tab("SQL生成"):  # "SQL generation"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")  # "Submit"
            text_button.click(generate, inputs=text_input, outputs=text_output)
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

50
pilot/server/sqlgpt.py Normal file

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import torch
import gradio as gr
from fastchat.serve.inference import generate_stream
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODE,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",
)


def generate(prompt):
    model.to(device)
    print(model, tokenizer)
    params = {
        "model": "vicuna-13b",
        # The Chinese preamble reads: "This is a conversation between a user and an assistant.
        # The assistant is an expert in the database domain and answers database questions very
        # professionally. The user's question follows:"
        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }
    output = generate_stream(
        model, tokenizer, params, device, context_len=2048, stream_interval=2)
    for chunk in output:
        yield chunk


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")  # "Database SQL generation assistant"
        with gr.Tab("SQL生成"):  # "SQL generation"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")  # "Submit"
            text_button.click(generate, inputs=text_input, outputs=text_output)
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")


@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from fastapi import FastAPI
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
model_path = llm_model_config[LLM_MODEL]
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
app = FastAPI()
class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: Optional[List[str]] = None


class EmbeddingRequest(BaseModel):
    prompt: str


@app.post("/generate")
def generate(prompt_request: PromptRequest):
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
    }
    print("Receive prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
    return {"response": output}


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = get_embeddings(model, tokenizer, params["prompt"])
    return {"response": [float(x) for x in output]}

22
pilot/utils.py Normal file

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
def get_gpu_memory(max_gpus=None):
    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
        if max_gpus is None
        else min(max_gpus, torch.cuda.device_count())
    )

    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            gpu_properties = torch.cuda.get_device_properties(device)
            total_memory = gpu_properties.total_memory / (1024 ** 3)
            allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
            available_memory = total_memory - allocated_memory
            gpu_memory.append(available_memory)
    return gpu_memory
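A short usage sketch for `get_gpu_memory`; the device-selection logic below is illustrative and not part of this PR:
```python
# Illustrative: pick the CUDA device with the most free memory (values are in GiB).
from pilot.utils import get_gpu_memory

free_gib = get_gpu_memory()
if free_gib:
    best_gpu = max(range(len(free_gib)), key=lambda i: free_gib[i])
    print(f"GPU {best_gpu} has {free_gib[best_gpu]:.1f} GiB free")
else:
    print("No CUDA devices detected")
```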

52
requirements.txt Normal file

@@ -0,0 +1,52 @@
accelerate==0.16.0
torch==2.0.0
torchvision==0.13.1
torchaudio==0.12.1
accelerate==0.16.0
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
bitsandbytes==0.37.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
regex==2022.10.31
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.28.0
timm==0.6.13
spacy==3.5.1
webdataset==0.2.48
scikit-learn==1.2.2
scipy==1.10.1
yarl==1.8.2
zipp==3.14.0
omegaconf==2.3.0
opencv-python==4.7.0.72
iopath==0.1.10
tenacity==8.2.2
peft
pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb