Merge pull request #13 from csunny/dev

qa based knowledge
magic.chen 2023-05-06 00:43:15 +08:00 committed by GitHub
commit 3db8e33f20
8 changed files with 138 additions and 49 deletions

View File

@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response

 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
+    docsearch = load_knownledge_from_doc()
     docs = docsearch.similarity_search_with_score(q, k=1)
     for doc in docs:

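For reference, a minimal usage sketch of the new retrieval path. It assumes pilot/datasets is populated and the local all-MiniLM-L6-v2 weights are available; the question string is only an example and is not part of this commit.

# Sketch: query the document store the same way get_similar() now does.
from pilot.vector_store.extract_tovec import load_knownledge_from_doc

docsearch = load_knownledge_from_doc()
docs = docsearch.similarity_search_with_score("What is the project plan?", k=1)
for doc, score in docs:
    # Each hit is a (Document, distance) pair; a lower distance means a closer match.
    print(score, doc.page_content[:200])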
View File

@@ -4,34 +4,34 @@
 import torch
 import os

-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
+DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }

 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"

 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False

 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }

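A short sketch of how the renamed constants are consumed downstream; the constants themselves come straight from the file above, only the print statements are illustrative.

# Sketch: the uppercase names are plain module-level paths, imported directly
# instead of the old lowercase names.
from pilot.configs.model_config import (
    ROOT_PATH, MODEL_PATH, VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG,
)

print(ROOT_PATH)                        # repository root
print(MODEL_PATH)                       # <ROOT_PATH>/models
print(VECTORE_PATH)                     # <ROOT_PATH>/vector_store
print(DATASETS_DIR)                     # <ROOT_PATH>/pilot/datasets
print(LLM_MODEL_CONFIG["vicuna-13b"])   # <ROOT_PATH>/models/vicuna-13b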
View File

@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *

 class VicunaRequestLLM(LLM):

@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()

@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
         print("Sending prompt ", p)
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
             json={
                 "prompt": p
             }

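A hedged sketch of the request that the renamed constant now feeds. The payload fields and the "generate" endpoint path are assumptions; the actual value of self.vicuna_generate_path is not shown in this diff.

# Illustrative only: POST to the Vicuna worker the way VicunaRequestLLM does.
import json
from urllib.parse import urljoin

import requests

from pilot.configs.model_config import VICUNA_MODEL_SERVER

params = {"prompt": "Hello", "stop": None}        # field names assumed
response = requests.post(
    url=urljoin(VICUNA_MODEL_SERVER, "generate"),  # hypothetical endpoint path
    data=json.dumps(params),
)
response.raise_for_status()
print(response.text)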
View File

@@ -14,7 +14,7 @@ from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *

-model_path = llm_model_config[LLM_MODEL]
+model_path = LLM_MODEL_CONFIG[LLM_MODEL]

 global_counter = 0

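On the model-server side the renamed dict resolves the checkpoint directory; a minimal sketch of that lookup (the existence check is added here for illustration and is not part of the commit):

# Sketch: LLM_MODEL selects the entry, LLM_MODEL_CONFIG maps it to local weights.
import os

from pilot.configs.model_config import LLM_MODEL, LLM_MODEL_CONFIG

model_path = LLM_MODEL_CONFIG[LLM_MODEL]   # e.g. <ROOT_PATH>/models/vicuna-13b
if not os.path.isdir(model_path):
    raise FileNotFoundError(f"model weights not found under {model_path}")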
View File

@@ -12,9 +12,9 @@ import requests
 from urllib.parse import urljoin

 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR

 from pilot.conversation import (
     default_conversation,

@@ -42,11 +42,22 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()

 priority = {
     "vicuna-13b": "aaa"
 }

+def get_simlar(q):
+    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
+    docs = docsearch.similarity_search_with_score(q, k=1)
+
+    contents = [dc.page_content for dc, _ in docs]
+    return "\n".join(contents)
+
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS

@@ -149,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

+    query = state.messages[-2][1]

     if len(state.messages) == state.offset + 2:
         # The first round of conversation needs the prompt prepended

@@ -157,11 +169,23 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.conv_id = uuid.uuid4().hex

         # Add the context hints to the prompt
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
+        if db_selector:
+            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+        else:
+            new_state.append_message(new_state.roles[0], query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+
+    try:
+        if not db_selector:
+            sim_q = get_simlar(query)
+            print("********vector similar info*************: ", sim_q)
+            state.append_message(new_state.roles[0], sim_q + query)
+    except Exception as e:
+        print(e)

     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1

@@ -181,7 +205,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:

@@ -236,6 +260,15 @@ pre {
     """
 )

+def change_tab(tab):
+    pass
+
+def change_mode(mode):
+    if mode == "默认知识库对话":
+        return gr.update(visible=False)
+    else:
+        return gr.update(visible=True)
+
 def build_single_model_ui():

@@ -270,12 +303,10 @@ def build_single_model_ui():
                 interactive=True,
                 label="最大输出Token数",
             )

-    with gr.Tabs():
-        with gr.TabItem("知识问答", elem_id="QA"):
-            pass
-
+    tabs = gr.Tabs()
+    with tabs:
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
                     label="请选择数据库",

@@ -283,6 +314,30 @@ def build_single_model_ui():
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)

+        with gr.TabItem("知识问答", elem_id="QA"):
+            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+            vs_setting = gr.Accordion("配置知识库", open=False)
+            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+            with vs_setting:
+                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                vs_add = gr.Button("添加为新知识库")
+
+                with gr.Column() as doc2vec:
+                    gr.Markdown("向知识库中添加文件")
+                    with gr.Tab("上传文件"):
+                        files = gr.File(label="添加文件",
+                                        file_types=[".txt", ".md", ".docx", ".pdf"],
+                                        file_count="multiple",
+                                        show_label=False)
+                        load_file_button = gr.Button("上传并加载到知识库")
+                    with gr.Tab("上传文件夹"):
+                        folder_files = gr.File(label="添加文件",
+                                               file_count="directory",
+                                               show_label=False)
+                        load_folder_button = gr.Button("上传并加载到知识库")
+
     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)

@@ -300,6 +355,7 @@ def build_single_model_ui():
         regenerate_btn = gr.Button(value="重新生成", interactive=False)
         clear_btn = gr.Button(value="清理", interactive=False)

     gr.Markdown(learn_more_markdown)
+
     btn_list = [regenerate_btn, clear_btn]

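The new change_mode handler only toggles the knowledge-base settings accordion. A stripped-down, self-contained sketch of the same wiring is below; the labels are copied from the diff, everything else (component layout, launch call) is illustrative and uses the Gradio 3.x API.

# Sketch: Radio -> Accordion visibility toggle, as used in the knowledge QA tab.
import gradio as gr


def change_mode(mode):
    # Hide the knowledge-base settings when chatting against the default store.
    return gr.update(visible=(mode != "默认知识库对话"))


with gr.Blocks() as demo:
    mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False,
                    value="默认知识库对话")
    vs_setting = gr.Accordion("配置知识库", open=False, visible=False)
    with vs_setting:
        vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
    mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

if __name__ == "__main__":
    demo.launch()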
View File

@@ -1,19 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-

+import os
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-# from langchain.embeddings import SentenceTransformerEmbeddings
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
+from langchain.embeddings import HuggingFaceEmbeddings

 embeddings = VicunaEmbeddingLLM()

 def knownledge_tovec(filename):
     with open(filename, "r") as f:
         knownledge = f.read()
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(

@@ -21,18 +22,48 @@ def knownledge_tovec(filename):
     )
     return docsearch

+def knownledge_tovec_st(filename):
+    """ Use sentence transformers to embedding the document.
+        https://github.com/UKPLab/sentence-transformers
+    """
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    with open(filename, "r") as f:
+        knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
+    return docsearch

-# def knownledge_tovec_st(filename):
-#     """ Use sentence transformers to embedding the document.
-#         https://github.com/UKPLab/sentence-transformers
-#     """
-#     from pilot.configs.model_config import llm_model_config
-#     embeddings = SentenceTransformerEmbeddings(model=llm_model_config["sentence-transforms"])
-#     with open(filename, "r") as f:
-#         knownledge = f.read()
-#     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-#     texts = text_splitter(knownledge)
-#     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
-#     return docsearch
+def load_knownledge_from_doc():
+    """Load knowledge from the local datasets.
+       # TODO: if the vector store already exists, skip the initialization.
+    """
+    if not os.path.exists(DATASETS_DIR):
+        print("Local datasets not found, the question will be answered with the model defaults.")
+
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    files = os.listdir(DATASETS_DIR)
+    for file in files:
+        if not os.path.isdir(file):
+            filename = os.path.join(DATASETS_DIR, file)
+            with open(filename, "r") as f:
+                knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
+                                  persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
+    return docsearch
+
+def get_vector_storelist():
+    if not os.path.exists(VECTORE_PATH):
+        return []
+    return os.listdir(VECTORE_PATH)

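load_knownledge_from_doc persists the store under <VECTORE_PATH>/.vectore, and its TODO notes the rebuild should be skipped when the store already exists. A hedged sketch of reusing the persisted index follows; it relies on the standard langchain Chroma API and is not code from this commit.

# Illustrative only: reopen the persisted Chroma store instead of re-embedding
# the datasets on every call.
import os

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

from pilot.configs.model_config import LLM_MODEL_CONFIG, VECTORE_PATH

persist_dir = os.path.join(VECTORE_PATH, ".vectore")
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])

if os.path.exists(persist_dir):
    # Reuse the already-built index.
    docsearch = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
else:
    # Fall back to building it from the datasets directory.
    from pilot.vector_store.extract_tovec import load_knownledge_from_doc
    docsearch = load_knownledge_from_doc()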
View File

@@ -50,4 +50,6 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
 fschat=0.1.10
+llama-index==0.5.27
+pymysql
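pymysql is added for the SQL generation tab. A minimal connection sketch built from the DB_SETTINGS dict in model_config.py is below; the MySQLOperator wrapper itself is not shown in this diff, so this is only an equivalent raw-driver example, and the SHOW DATABASES query is illustrative.

# Illustrative only: a pymysql connection built from DB_SETTINGS.
import pymysql

from pilot.configs.model_config import DB_SETTINGS

conn = pymysql.connect(
    host=DB_SETTINGS["host"],
    port=DB_SETTINGS["port"],
    user=DB_SETTINGS["user"],
    password=DB_SETTINGS["password"],
    charset="utf8mb4",
)
with conn.cursor() as cursor:
    cursor.execute("SHOW DATABASES")
    print([row[0] for row in cursor.fetchall()])
conn.close()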