fix merge problem

csunny 2023-04-29 21:29:29 +08:00
commit 1b70b5f92b
11 changed files with 558 additions and 255 deletions

.vscode/launch.json (vendored, new file, +25 lines)

@@ -0,0 +1,25 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {"PYTHONPATH": "${workspaceFolder}"},
            "envFile": "${workspaceFolder}/.env"
        },
        {
            "name": "Python: Module",
            "type": "python",
            "request": "launch",
            "module": "pilot",
            "justMyCode": true
        }
    ]
}

examples/t5_example.py (new file, +240 lines)

@@ -0,0 +1,240 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from llama_index import SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTSimpleVectorIndex, PromptHelper
from llama_index import LLMPredictor
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from transformers import pipeline

class FlanLLM(LLM):
    """Minimal langchain LLM wrapper around a local flan-t5 text2text pipeline."""

    model_name = "google/flan-t5-large"
    pipeline = pipeline("text2text-generation", model=model_name, device=0, model_kwargs={
        "torch_dtype": torch.bfloat16
    })

    def _call(self, prompt, stop=None):
        return self.pipeline(prompt, max_length=9999)[0]["generated_text"]

    @property
    def _identifying_params(self):
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self):
        return "custom"

llm_predictor = LLMPredictor(llm=FlanLLM())
hfemb = HuggingFaceEmbeddings()
embed_model = LangchainEmbedding(hfemb)
text1 = """
执行计划是对一条 SQL 查询语句在数据库中执行过程的描述用户可以通过 EXPLAIN 命令查看优化器针对指定 SQL 生成的逻辑执行计划
如果要分析某条 SQL 的性能问题通常需要先查看 SQL 的执行计划排查每一步 SQL 执行是否存在问题所以读懂执行计划是 SQL 优化的先决条件而了解执行计划的算子是理解 EXPLAIN 命令的关键
OceanBase 数据库的执行计划命令有三种模式EXPLAIN BASICEXPLAIN EXPLAIN EXTENDED这三种模式对执行计划展现不同粒度的细节信息:
EXPLAIN BASIC 命令用于最基本的计划展示
EXPLAIN EXTENDED 命令用于最详细的计划展示通常在排查问题时使用这种展示模式
EXPLAIN 命令所展示的信息可以帮助普通用户了解整个计划的执行方式
EXPLAIN 命令格式如下
EXPLAIN [BASIC | EXTENDED | PARTITIONS | FORMAT = format_name] [PRETTY | PRETTY_COLOR] explainable_stmt
format_name:
{ TRADITIONAL | JSON }
explainable_stmt:
{ SELECT statement
| DELETE statement
| INSERT statement
| REPLACE statement
| UPDATE statement }
EXPLAIN 命令适用于 SELECTDELETEINSERTREPLACE UPDATE 语句显示优化器所提供的有关语句执行计划的信息包括如何处理该语句如何联接表以及以何种顺序联接表等信息
一般来说可以使用 EXPLAIN EXTENDED 命令将表扫描的范围段展示出来使用 EXPLAIN OUTLINE 命令可以显示 Outline 信息
FORMAT 选项可用于选择输出格式TRADITIONAL 表示以表格格式显示输出这也是默认设置JSON 表示以 JSON 格式显示信息
使用 EXPLAIN PARTITITIONS 也可用于检查涉及分区表的查询如果检查针对非分区表的查询则不会产生错误 PARTIONS 列的值始终为 NULL
对于复杂的执行计划可以使用 PRETTY 或者 PRETTY_COLOR 选项将计划树中的父节点和子节点使用树线或彩色树线连接起来使得执行计划展示更方便阅读示例如下
obclient> CREATE TABLE p1table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 2;
Query OK, 0 rows affected
obclient> CREATE TABLE p2table(c1 INT ,c2 INT) PARTITION BY HASH(c1) PARTITIONS 4;
Query OK, 0 rows affected
obclient> EXPLAIN EXTENDED PRETTY_COLOR SELECT * FROM p1table p1 JOIN p2table p2 ON p1.c1=p2.c2\G
*************************** 1. row ***************************
Query Plan: ==========================================================
|ID|OPERATOR |NAME |EST. ROWS|COST|
----------------------------------------------------------
|0 |PX COORDINATOR | |1 |278 |
|1 | EXCHANGE OUT DISTR |:EX10001|1 |277 |
|2 | HASH JOIN | |1 |276 |
|3 | PX PARTITION ITERATOR | |1 |92 |
|4 | TABLE SCAN |P1 |1 |92 |
|5 | EXCHANGE IN DISTR | |1 |184 |
|6 | EXCHANGE OUT DISTR (PKEY)|:EX10000|1 |184 |
|7 | PX PARTITION ITERATOR | |1 |183 |
|8 | TABLE SCAN |P2 |1 |183 |
==========================================================
Outputs & filters:
-------------------------------------
0 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil)
1 - output([INTERNAL_FUNCTION(P1.C1, P1.C2, P2.C1, P2.C2)]), filter(nil), dop=1
2 - output([P1.C1], [P2.C2], [P1.C2], [P2.C1]), filter(nil),
equal_conds([P1.C1 = P2.C2]), other_conds(nil)
3 - output([P1.C1], [P1.C2]), filter(nil)
4 - output([P1.C1], [P1.C2]), filter(nil),
access([P1.C1], [P1.C2]), partitions(p[0-1])
5 - output([P2.C2], [P2.C1]), filter(nil)
6 - (#keys=1, [P2.C2]), output([P2.C2], [P2.C1]), filter(nil), dop=1
7 - output([P2.C1], [P2.C2]), filter(nil)
8 - output([P2.C1], [P2.C2]), filter(nil),
access([P2.C1], [P2.C2]), partitions(p[0-3])
1 row in set
## 执行计划形状与算子信息
在数据库系统中执行计划在内部通常是以树的形式来表示的但是不同的数据库会选择不同的方式展示给用户
如下示例分别为 PostgreSQL 数据库Oracle 数据库和 OceanBase 数据库对于 TPCDS Q3 的计划展示
```sql
obclient> SELECT /*TPC-DS Q3*/ *
FROM (SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
Sum(ss_net_profit) sum_agg
FROM date_dim dt,
store_sales,
item
WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
AND store_sales.ss_item_sk = item.i_item_sk
AND item.i_manufact_id = 914
AND dt.d_moy = 11
GROUP BY dt.d_year,
item.i_brand,
item.i_brand_id
ORDER BY dt.d_year,
sum_agg DESC,
brand_id)
WHERE ROWNUM <= 100;
PostgreSQL 数据库执行计划展示如下
Limit (cost=13986.86..13987.20 rows=27 width=91)
Sort (cost=13986.86..13986.93 rows=27 width=65)
Sort Key: dt.d_year, (sum(store_sales.ss_net_profit)), item.i_brand_id
HashAggregate (cost=13985.95..13986.22 rows=27 width=65)
Merge Join (cost=13884.21..13983.91 rows=204 width=65)
Merge Cond: (dt.d_date_sk = store_sales.ss_sold_date_sk)
Index Scan using date_dim_pkey on date_dim dt (cost=0.00..3494.62 rows=6080 width=8)
Filter: (d_moy = 11)
Sort (cost=12170.87..12177.27 rows=2560 width=65)
Sort Key: store_sales.ss_sold_date_sk
Nested Loop (cost=6.02..12025.94 rows=2560 width=65)
Seq Scan on item (cost=0.00..1455.00 rows=16 width=59)
Filter: (i_manufact_id = 914)
Bitmap Heap Scan on store_sales (cost=6.02..658.94 rows=174 width=14)
Recheck Cond: (ss_item_sk = item.i_item_sk)
Bitmap Index Scan on store_sales_pkey (cost=0.00..5.97 rows=174 width=0)
Index Cond: (ss_item_sk = item.i_item_sk)
Oracle 数据库执行计划展示如下
Plan hash value: 2331821367
--------------------------------------------------------------------------------------------------
| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time |
--------------------------------------------------------------------------------------------------
| 0 | SELECT STATEMENT | | 100 | 9100 | 3688 (1)| 00:00:01 |
|* 1 | COUNT STOPKEY | | | | | |
| 2 | VIEW | | 2736 | 243K| 3688 (1)| 00:00:01 |
|* 3 | SORT ORDER BY STOPKEY | | 2736 | 256K| 3688 (1)| 00:00:01 |
| 4 | HASH GROUP BY | | 2736 | 256K| 3688 (1)| 00:00:01 |
|* 5 | HASH JOIN | | 2736 | 256K| 3686 (1)| 00:00:01 |
|* 6 | TABLE ACCESS FULL | DATE_DIM | 6087 | 79131 | 376 (1)| 00:00:01 |
| 7 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
| 8 | NESTED LOOPS | | 2865 | 232K| 3310 (1)| 00:00:01 |
|* 9 | TABLE ACCESS FULL | ITEM | 18 | 1188 | 375 (0)| 00:00:01 |
|* 10 | INDEX RANGE SCAN | SYS_C0010069 | 159 | | 2 (0)| 00:00:01 |
| 11 | TABLE ACCESS BY INDEX ROWID| STORE_SALES | 159 | 2703 | 163 (0)| 00:00:01 |
--------------------------------------------------------------------------------------------------
OceanBase 数据库执行计划展示如下
|ID|OPERATOR |NAME |EST. ROWS|COST |
-------------------------------------------------------
|0 |LIMIT | |100 |81141|
|1 | TOP-N SORT | |100 |81127|
|2 | HASH GROUP BY | |2924 |68551|
|3 | HASH JOIN | |2924 |65004|
|4 | SUBPLAN SCAN |VIEW1 |2953 |19070|
|5 | HASH GROUP BY | |2953 |18662|
|6 | NESTED-LOOP JOIN| |2953 |15080|
|7 | TABLE SCAN |ITEM |19 |11841|
|8 | TABLE SCAN |STORE_SALES|161 |73 |
|9 | TABLE SCAN |DT |6088 |29401|
=======================================================
由示例可见OceanBase 数据库的计划展示与 Oracle 数据库类似
OceanBase 数据库执行计划中的各列的含义如下
列名 含义
ID 执行树按照前序遍历的方式得到的编号 0 开始
OPERATOR 操作算子的名称
NAME 对应表操作的表名索引名
EST. ROWS 估算该操作算子的输出行数
COST 该操作算子的执行代价微秒
OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形结构展示其中每一个操作在树中的层次通过其在 operator 中的缩进予以展示层次最深的优先执行层次相同的以特定算子的执行顺序为标准来执行
问题: update a not exists (b)
我一开始以为 B是驱动表B的数据挺多的 后来看到NLAJ是说左边的表关联右边的表
所以这个的驱动表是不是实际是A用A的匹配B的这个理解有问题吗
回答: 没错 A 驱动 B的
问题: 光知道最下最右的是驱动表了 所以一开始搞得有点懵 :sweat_smile:
回答: nlj应该原理应该都是左表(驱动表)的记录探测右表(被驱动表) 选哪张成为左表或右表就基于一些其他考量了比如数据量 而anti join/semi join只是对 not exist/exist的一种优化相关的原理和资料网上可以查阅一下
问题: 也就是nlj 就是按照之前理解的谁先执行 谁就是驱动表 也就是执行计划中的最右的表
而anti join/semi join谁在not exist左面谁就是驱动表这么理解对吧
回答: nlj也是左表的表是驱动表这个要了解下计划执行方面的基本原理取左表的一行数据再遍历右表一旦满足连接条件就可以返回数据
anti/semi只是因为not exists/exist的语义只是返回左表数据改成anti join是一种计划优化连接的方式比子查询更优
"""
from llama_index import Document

text_list = [text1]
documents = [Document(t) for t in text_list]

# PromptHelper caps how much of the corpus is packed into each LLM call.
num_output = 250
max_input_size = 512
max_chunk_overlap = 20
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)

index = GPTListIndex(documents, embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper)
index.save_to_disk("index.json")

if __name__ == "__main__":
    import logging
    logging.getLogger().setLevel(logging.CRITICAL)
    for d in documents:
        print(d)

    response = index.query("数据库的执行计划命令有多少?")
    print(response)
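
Since the script persists the index with save_to_disk, a later run can reload it instead of re-embedding the corpus. A minimal sketch, assuming the same llama_index release used above (the one that exposes the save_to_disk/load_from_disk API):

# Sketch: reload the persisted index and query it again. The kwargs mirror
# the construction above; "index.json" is the path written by save_to_disk.
from llama_index import GPTListIndex

loaded_index = GPTListIndex.load_from_disk(
    "index.json",
    embed_model=embed_model,
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
)
print(loaded_index.query("How many EXPLAIN modes does the database support?"))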

@@ -1,3 +1 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__version__ = "0.0.1"

@@ -1,241 +1,18 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from langchain.agents import (
    load_tools,
    initialize_agent,
    AgentType
)
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM

llm = VicunaRequestLLM()

tools = load_tools(['python_repl'], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
agent.run(
    "Write a python script that prints 'Hello World!'"
)

(The rest of the old file, deleted by this hunk, was a verbatim copy of examples/t5_example.py above: the FlanLLM wrapper, the OceanBase EXPLAIN corpus, and the GPTListIndex build/query code.)

pilot/configs/model_config.py

@@ -9,9 +9,18 @@ model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
llm_model_config = {
    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
    "vicuna-13b": os.path.join(model_path, "vicuna-13b")
}

LLM_MODEL = "vicuna-13b"
vicuna_model_server = "http://192.168.31.114:21000/"

# Load model config
isload_8bit = True
isdebug = False
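
Everything downstream pulls these names in with a star-import, so the dictionary above is the single lookup from model name to checkpoint directory. A minimal sketch of that resolution, mirroring what the loader and server below do (illustrative only, not part of the commit):

# Sketch: resolve the active model the way ModerLoader's callers do.
import os
from pilot.configs.model_config import llm_model_config, LLM_MODEL, DEVICE

model_checkpoint = llm_model_config[LLM_MODEL]
print(LLM_MODEL, "->", model_checkpoint, "on", DEVICE)
assert os.path.isdir(model_checkpoint), "download the weights into ./models first"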

pilot/model/inference.py (new file, +86 lines)

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch

@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))

    # Normalize the stop parameter into a list of stop strings.
    stop_parameter = params.get("stop", None)
    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")

    input_ids = tokenizer(prompt).input_ids
    output_ids = []

    # Keep room in the context window for the new tokens plus a small margin.
    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    stopped = False
    for i in range(max_new_tokens):
        if i == 0:
            # Prime the KV cache with the full prompt.
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            # Feed only the last sampled token, reusing the cache.
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]
        if temperature < 1e-4:
            # Greedy decoding at (near-)zero temperature.
            token = int(torch.argmax(last_token_logits))
        else:
            # last_token_logits is 1-D, so softmax over its only dimension.
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True

        # Decode what we have so far and truncate at the first stop string.
        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop_strings:
            pos = output.rfind(stop_str)
            if pos != -1:
                output = output[:pos]
                stopped = True
                break
        if stopped:
            break

    del past_key_values
    return output

@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    input_ids = tokenizer(prompt).input_ids
    input_embeddings = model.get_input_embeddings()
    embeddings = input_embeddings(torch.LongTensor([input_ids]))
    mean = torch.mean(embeddings[0], 0).cpu().detach()
    return mean
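
For reference, a minimal way to exercise these helpers end to end. The tiny `gpt2` checkpoint here is a stand-in for a quick CPU test (the commit itself loads vicuna-13b through ModerLoader), so treat this as an illustrative sketch rather than part of the change:

# Illustrative smoke test; any small causal LM works. Not part of the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from pilot.model.inference import generate_output, get_embeddings

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

params = {"prompt": "Hello,", "temperature": 0.7, "max_new_tokens": 16, "stop": "###"}
print(generate_output(model, tokenizer, params, device="cpu"))
print(get_embeddings(model, tokenizer, "Hello,").shape)  # e.g. torch.Size([768])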

pilot/model/loader.py

@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-

import torch
from pilot.utils import get_gpu_memory
from fastchat.serve.inference import compress_module
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,

@@ -28,12 +26,12 @@ class ModerLoader:
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)

        if load_8bit:
            compress_module(model, self.device)

        if debug:
            print(model)

        if self.device == "cuda":
            model.to(self.device)

        return model, tokenizer

pilot/model/vicuna_llm.py

@@ -1,9 +1,84 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from transformers import pipeline
import json
import requests
from urllib.parse import urljoin
from typing import Any, Mapping, Optional, List

from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM
from pydantic import BaseModel

from configs.model_config import *

class VicunaLLM(LLM):
    model_name = llm_model_config[LLM_MODEL]

class VicunaRequestLLM(LLM):

    vicuna_generate_path = "generate"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if isinstance(stop, list):
            stop = stop + ["Observation:"]
        params = {
            "prompt": prompt,
            "temperature": 0,
            "max_new_tokens": 256,
            "stop": stop
        }
        response = requests.post(
            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
            data=json.dumps(params)
        )
        response.raise_for_status()
        return response.json()["response"]

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}

class VicunaEmbeddingLLM(BaseModel, Embeddings):

    vicuna_embedding_path = "embedding"

    def _call(self, prompt: str) -> str:
        p = prompt.strip()
        print("Sending prompt ", p)
        response = requests.post(
            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
            json={
                "prompt": p
            }
        )
        response.raise_for_status()
        return response.json()["response"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to the Vicuna server's embedding endpoint for search docs.

        Args:
            texts: The list of texts to embed.
        Returns:
            List of embeddings, one for each text.
        """
        results = []
        for text in texts:
            response = self.embed_query(text)
            results.append(response)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Call out to the Vicuna server's embedding endpoint for query text.

        Args:
            text: The text to embed.
        Returns:
            Embedding for the text.
        """
        embedding = self._call(text)
        return embedding
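
A quick sketch of how these wrappers are consumed; this mirrors the agent example earlier in the commit and assumes the model server at vicuna_model_server is reachable:

# Sketch: the request-based LLM behaves like any langchain LLM, and the
# embedding class feeds vector search; both just proxy the HTTP endpoints.
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM

llm = VicunaRequestLLM()
print(llm("Explain what a query execution plan is in one sentence."))

emb = VicunaEmbeddingLLM()
vector = emb.embed_query("EXPLAIN BASIC")
print(len(vector))  # embedding dimension reported by the server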

@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import requests
import json
import time
from urllib.parse import urljoin

import gradio as gr

from configs.model_config import *

vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"

def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }

    # Ping the worker's status endpoint before streaming.
    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
    )

    # Length of the prompt echoed back in each streamed chunk (computed but
    # not applied here; see the sketch after this file).
    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3

    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                yield output
            time.sleep(0.02)

if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")
        with gr.Tab("SQL生成"):
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")

            text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
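
One loose end worth noting: skip_echo_len is computed but never used. In the fastchat streaming protocol this client targets, each data["text"] chunk contains the prompt followed by the partial completion, and clients slice the echo off. A hedged variant of generate showing that convention (an assumption about the worker's behavior, not something this commit does):

def generate_without_echo(prompt):
    # Sketch only: same request as generate(), but strips the echoed prompt
    # from each streamed chunk, as fastchat's demo clients do.
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }
    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                yield data["text"][skip_echo_len:].strip()
            time.sleep(0.02)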

@@ -7,7 +7,6 @@ import torch
import gradio as gr
from fastchat.serve.inference import generate_stream, compress_module
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -21,12 +20,12 @@ model = AutoModelForCausalLM.from_pretrained(
)

def generate(prompt):
    # compress_module(model, device)
    # model.to(device)
    compress_module(model, device)
    model.to(device)
    print(model, tokenizer)
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"

@@ -36,9 +35,6 @@ def generate(prompt):
    for chunk in output:
        yield chunk

    # for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
    #     if chunk:
    #         yield chunk

if __name__ == "__main__":
    with gr.Blocks() as demo:

@@ -53,5 +49,3 @@ if __name__ == "__main__":
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import Optional, List

from fastapi import FastAPI
from pydantic import BaseModel

from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *

model_path = llm_model_config[LLM_MODEL]

ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
app = FastAPI()

class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: Optional[List[str]] = None

class EmbeddingRequest(BaseModel):
    prompt: str

@app.post("/generate")
def generate(prompt_request: PromptRequest):
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
    }
    print("Receive prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
    return {"response": output}

@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = get_embeddings(model, tokenizer, params["prompt"])
    return {"response": [float(x) for x in output]}
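
For completeness, a client-side sketch against the two routes above. The base URL reuses vicuna_model_server from model_config.py; note that stop must be a list (or omitted) to satisfy PromptRequest:

# Hypothetical client for /generate and /embedding; not part of the commit.
import requests

base = "http://192.168.31.114:21000"

out = requests.post(f"{base}/generate", json={
    "prompt": "What is an execution plan?",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stop": ["###"],
})
print(out.json()["response"])

emb = requests.post(f"{base}/embedding", json={"prompt": "EXPLAIN BASIC"})
print(len(emb.json()["response"]))  # embedding dimension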