diff --git a/pilot/app.py b/pilot/app.py
index 5456621f2..1cbcd79b0 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
+    docsearch = load_knownledge_from_doc()
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     for doc in docs:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 1238d1bcb..4527fafe0 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -4,34 +4,34 @@
 import torch
 import os
 
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
-
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
+DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
\ No newline at end of file
diff --git a/pilot/datasets/__init__.py b/pilot/datasets/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index 26673344f..f17a17a00 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *
 
 class VicunaRequestLLM(LLM):
 
@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
             print("Sending prompt ", p)
 
             response = requests.post(
-                url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+                url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
                 json={
                     "prompt": p
                 }
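Reviewer note: the constants renamed above (`MODEL_PATH`, `VECTORE_PATH`, `LLM_MODEL_CONFIG`, `VICUNA_MODEL_SERVER`, `DATASETS_DIR`) are what the rest of this patch consumes. A minimal sketch of the lookup they enable — illustrative only, not part of the patch:

```python
# Resolve the weights directory for the active model via the renamed constants.
from pilot.configs.model_config import LLM_MODEL_CONFIG, LLM_MODEL

model_path = LLM_MODEL_CONFIG[LLM_MODEL]  # e.g. <ROOT_PATH>/models/vicuna-13b
```
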
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index dba68699e..79bc1dab3 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -14,7 +14,7 @@ from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
-model_path = llm_model_config[LLM_MODEL]
+model_path = LLM_MODEL_CONFIG[LLM_MODEL]
 
 
 global_counter = 0
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index c13a5331f..b7a0ef816 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,9 +12,9 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
 
 from pilot.conversation import (
     default_conversation,
@@ -42,11 +42,22 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()
 
 priority = {
     "vicuna-13b": "aaa"
 }
 
+def get_simlar(q):
+
+    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
+    docs = docsearch.similarity_search_with_score(q, k=1)
+
+    contents = [dc.page_content for dc, _ in docs]
+    return "\n".join(contents)
+
+
+
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS
@@ -149,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    query = state.messages[-2][1]
     if len(state.messages) == state.offset + 2:
         # The first round of dialogue needs the prompt added
@@ -157,11 +169,23 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
         new_state.conv_id = uuid.uuid4().hex
 
         # Add context hints to the prompt
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
-
-
+        if db_selector:
+            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+        else:
+            new_state.append_message(new_state.roles[0], query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+
+        try:
+            if not db_selector:
+                sim_q = get_simlar(query)
+                print("********vector similar info*************: ", sim_q)
+                state.append_message(new_state.roles[0], sim_q + query)
+        except Exception as e:
+            print(e)
+
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@@ -181,7 +205,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
             headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -236,6 +260,15 @@ pre {
 """
 )
 
+def change_tab(tab):
+    pass
+
+def change_mode(mode):
+    if mode == "默认知识库对话":
+        return gr.update(visible=False)
+    else:
+        return gr.update(visible=True)
+
 
 def build_single_model_ui():
@@ -270,12 +303,10 @@ def build_single_model_ui():
                 interactive=True,
                 label="最大输出Token数",
             )
-
-    with gr.Tabs():
-        with gr.TabItem("知识问答", elem_id="QA"):
-            pass
+    tabs = gr.Tabs()
+    with tabs:
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
-            # TODO A selector to choose database
+            # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
                     label="请选择数据库",
@@ -283,6 +314,30 @@ def build_single_model_ui():
                     choices=dbs,
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)
+
+        with gr.TabItem("知识问答", elem_id="QA"):
+            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+            vs_setting = gr.Accordion("配置知识库", open=False)
+            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+            with vs_setting:
+                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                vs_add = gr.Button("添加为新知识库")
+                with gr.Column() as doc2vec:
+                    gr.Markdown("向知识库中添加文件")
+                    with gr.Tab("上传文件"):
+                        files = gr.File(label="添加文件",
+                                        file_types=[".txt", ".md", ".docx", ".pdf"],
+                                        file_count="multiple",
+                                        show_label=False
+                                        )
+
+                        load_file_button = gr.Button("上传并加载到知识库")
+                    with gr.Tab("上传文件夹"):
+                        folder_files = gr.File(label="添加文件",
+                                               file_count="directory",
+                                               show_label=False)
+                        load_folder_button = gr.Button("上传并加载到知识库")
+
 
     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@@ -300,6 +355,7 @@ def build_single_model_ui():
             regenerate_btn = gr.Button(value="重新生成", interactive=False)
             clear_btn = gr.Button(value="清理", interactive=False)
 
+
     gr.Markdown(learn_more_markdown)
 
     btn_list = [regenerate_btn, clear_btn]
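Reviewer note: the new QA branch in `http_bot` prepends vector-store context to the user query before prompting. The same flow in isolation, as a minimal sketch (assumes a `plan.md` under `DATASETS_DIR`, matching `get_simlar` above; the question string is hypothetical):

```python
import os

from pilot.configs.model_config import DATASETS_DIR
from pilot.vector_store.extract_tovec import knownledge_tovec_st

query = "How do I scale out the cluster?"  # hypothetical user question

# Embed the document, fetch the single closest chunk, and prepend it to the
# query -- mirroring the `if not db_selector` path in http_bot.
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(query, k=1)
context = "\n".join(dc.page_content for dc, _ in docs)
prompt_text = context + query
```
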
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 5b7df3eb2..8badf6fed 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -1,19 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import os
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-# from langchain.embeddings import SentenceTransformerEmbeddings
-
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
+from langchain.embeddings import HuggingFaceEmbeddings
 
 embeddings = VicunaEmbeddingLLM()
 
 def knownledge_tovec(filename):
     with open(filename, "r") as f:
         knownledge = f.read()
-
+
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(
@@ -21,18 +22,48 @@ def knownledge_tovec(filename):
     )
     return docsearch
 
+def knownledge_tovec_st(filename):
+    """Use sentence-transformers to embed the document.
+       https://github.com/UKPLab/sentence-transformers
+    """
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
-# def knownledge_tovec_st(filename):
-#     """ Use sentence transformers to embedding the document.
-#     https://github.com/UKPLab/sentence-transformers
-#     """
-#     from pilot.configs.model_config import llm_model_config
-#     embeddings = SentenceTransformerEmbeddings(model=llm_model_config["sentence-transforms"])
+
+    with open(filename, "r") as f:
+        knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-# with open(filename, "r") as f:
-#     knownledge = f.read()
-
-# text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-# texts = text_splitter(knownledge)
-# docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
-# return docsearch
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
+    return docsearch
+
+
+def load_knownledge_from_doc():
+    """Load knowledge from the datasets directory.
+       # TODO: skip initialization if the vector store already exists
+    """
+
+    if not os.path.exists(DATASETS_DIR):
+        print("Local datasets do not exist, the model's built-in knowledge will be used to answer.")
+
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    files = os.listdir(DATASETS_DIR)
+    for file in files:
+        if not os.path.isdir(file):
+            filename = os.path.join(DATASETS_DIR, file)
+            with open(filename, "r") as f:
+                knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
+                                  persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
+    return docsearch
+
+def get_vector_storelist():
+    if not os.path.exists(VECTORE_PATH):
+        return []
+    return os.listdir(VECTORE_PATH)
\ No newline at end of file
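Reviewer note: the `# TODO` in `load_knownledge_from_doc` asks to skip re-embedding once the store has been persisted. A hedged sketch of that follow-up — `load_or_build_store` is a hypothetical helper, not part of this patch; reloading via `Chroma(persist_directory=..., embedding_function=...)` is standard langchain usage:

```python
import os

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from pilot.configs.model_config import LLM_MODEL_CONFIG, VECTORE_PATH

def load_or_build_store(build_fn):
    """Reload the persisted Chroma store if present; otherwise build it."""
    persist_dir = os.path.join(VECTORE_PATH, ".vectore")  # same location the patch persists to
    if os.path.exists(persist_dir):
        embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
        return Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    return build_fn()  # e.g. load_knownledge_from_doc()
```
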
diff --git a/requirements.txt b/requirements.txt
index 50354b3b4..f0ddf8fb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -50,4 +50,6 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
-fschat=0.1.10
\ No newline at end of file
+fschat==0.1.10
+llama-index==0.5.27
+pymysql
\ No newline at end of file