From a164d2f156833ba475d8ed2e190c65a5bbf4737a Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 20:06:24 +0800
Subject: [PATCH 1/9] update config file

---
 pilot/configs/model_config.py | 24 ++++++++++++------------
 pilot/model/vicuna_llm.py     |  4 ++--
 pilot/server/webserver.py     |  4 ++--
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 1238d1bcb..675c51b66 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -4,34 +4,34 @@
 import torch
 import os
 
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
\ No newline at end of file
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index 26673344f..eba2834ae 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
             print("Sending prompt ", p)
 
             response = requests.post(
-                url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+                url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
                 json={
                     "prompt": p
                 }
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index c13a5331f..4468082d8 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -14,7 +14,7 @@
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
 
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
 from pilot.conversation import (
     default_conversation,
@@ -181,7 +181,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
 
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
             headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
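
Patch 1 is a rename-only pass (lower-case config globals become module-level constants), but it also pins down the wire protocol the webserver speaks: POST to VICUNA_MODEL_SERVER's /generate_stream endpoint and read NUL-delimited JSON chunks. A minimal standalone client sketch of that protocol follows; the payload keys and the "text" response field are assumptions modeled on the fastchat worker convention, not something this diff confirms.

    # Sketch: consume the NUL-delimited generate_stream endpoint.
    import json
    from urllib.parse import urljoin

    import requests

    VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"  # as in model_config.py

    def stream_generate(prompt, temperature=0.7, max_new_tokens=512):
        # Payload keys follow the fastchat worker convention (assumption).
        payload = {
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": None,
        }
        response = requests.post(
            urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
            json=payload,
            stream=True,
            timeout=20,
        )
        response.raise_for_status()
        # Chunks are separated by b"\0", matching the iter_lines call above.
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode("utf-8"))
                yield data.get("text", "")

    if __name__ == "__main__":
        for partial in stream_generate("Hello"):
            print(partial)
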
From b7b4a1fb63ba8a00cda7fbf3302cd15ff358bfca Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 20:21:39 +0800
Subject: [PATCH 2/9] update

---
 pilot/server/vicuna_server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index dba68699e..79bc1dab3 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -14,7 +14,7 @@ from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
-model_path = llm_model_config[LLM_MODEL]
+model_path = LLM_MODEL_CONFIG[LLM_MODEL]
 
 
 global_counter = 0
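
Patch 2 catches a straggler from the rename: vicuna_server.py still read the old llm_model_config name, which fails only at startup because the config arrives via a star import. A guarded lookup like the sketch below surfaces the problem with a clearer error; the paths are illustrative stand-ins for the real constants.

    import os

    MODEL_PATH = "/data/models"  # stand-in for the ROOT_PATH-derived constant
    LLM_MODEL_CONFIG = {
        "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
        "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
    }

    def resolve_model_path(name):
        # Fail with the list of known models instead of a bare KeyError.
        try:
            return LLM_MODEL_CONFIG[name]
        except KeyError:
            known = ", ".join(sorted(LLM_MODEL_CONFIG))
            raise ValueError(f"Unknown model {name!r}; known models: {known}")

    print(resolve_model_path("vicuna-13b"))
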
From 205eab72683210244ed2a63e7a3b25bbb637c257 Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 21:40:23 +0800
Subject: [PATCH 3/9] sentence transformer

---
 pilot/server/webserver.py           | 23 ++++++++++++++++---
 pilot/vector_store/extract_tovec.py | 35 ++++++++++++++++-------------
 2 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 4468082d8..4bf0d0eb4 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -236,6 +236,12 @@ pre {
 """
 )
 
+def change_mode(mode):
+    if mode == "默认知识库对话":
+        return gr.update(visible=False)
+    else:
+        return gr.update(visible=True)
+
 
 def build_single_model_ui():
 
@@ -249,6 +255,7 @@ def build_single_model_ui():
     The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
     """
+    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State(vs_list)
     state = gr.State()
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
@@ -270,10 +277,16 @@ def build_single_model_ui():
             interactive=True,
             label="最大输出Token数",
         )
-
-    with gr.Tabs():
+    tabs = gr.Tabs()
+    with tabs:
         with gr.TabItem("知识问答", elem_id="QA"):
-            pass
+            doc2vec = gr.Column(visible=False)
+            with doc2vec:
+                mode = gr.Radio(["默认知识库对话", "新增知识库"])
+                vs_setting = gr.Accordion("配置知识库")
+                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+                with vs_setting:
+                    select_vs = gr.Dropdown()
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -300,6 +313,10 @@ def build_single_model_ui():
             regenerate_btn = gr.Button(value="重新生成", interactive=False)
             clear_btn = gr.Button(value="清理", interactive=False)
 
+    # Clear the database selector option in QA mode
+    if tabs.elem_id == "QA":
+        db_selector = ""
+
     gr.Markdown(learn_more_markdown)
 
     btn_list = [regenerate_btn, clear_btn]
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 5b7df3eb2..2581b0264 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import os
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-# from langchain.embeddings import SentenceTransformerEmbeddings
-
+from pilot.configs.model_config import VECTORE_PATH
+from langchain.embeddings import HuggingFaceEmbeddings
 
 embeddings = VicunaEmbeddingLLM()
 
@@ -21,18 +22,22 @@ def knownledge_tovec(filename):
     )
     return docsearch
 
+def knownledge_tovec_st(filename):
+    """ Use sentence transformers to embed the document.
+        https://github.com/UKPLab/sentence-transformers
+    """
+    from pilot.configs.model_config import llm_model_config
+    embeddings = HuggingFaceEmbeddings(model=llm_model_config["sentence-transforms"])
 
-# def knownledge_tovec_st(filename):
-#     """ Use sentence transformers to embedding the document.
-#         https://github.com/UKPLab/sentence-transformers
-#     """
-#     from pilot.configs.model_config import llm_model_config
-#     embeddings = SentenceTransformerEmbeddings(model=llm_model_config["sentence-transforms"])
+    with open(filename, "r") as f:
+        knownledge = f.read()
 
-#     with open(filename, "r") as f:
-#         knownledge = f.read()
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
+    return docsearch
 
-#     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-#     texts = text_splitter(knownledge)
-#     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
-#     return docsearch
+def get_vector_storelist():
+    if not os.path.exists(VECTORE_PATH):
+        return []
+    return os.listdir(VECTORE_PATH)
\ No newline at end of file
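
Patch 3 activates the sentence-transformers path that used to be commented out. Two bugs in it are fixed by later patches: HuggingFaceEmbeddings takes model_name, not model (patch 4), and the splitter instance is not callable, so it must be invoked as split_text() (patch 5). A corrected end-to-end sketch of the same pipeline, with the model and file names as examples:

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import Chroma

    # model_name=, not model= (the kwarg patch 4 corrects)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    with open("datasets/plan.md", "r") as f:
        knownledge = f.read()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(knownledge)  # split_text(), per patch 5

    docsearch = Chroma.from_texts(
        texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]
    )
    for doc, score in docsearch.similarity_search_with_score("deployment plan", k=1):
        print(score, doc.page_content[:200])
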
From 7da910b6428f205e0ab86fbb73df7b4f44ec9ccb Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 22:14:28 +0800
Subject: [PATCH 4/9] add gradio tem

---
 pilot/app.py                        |  4 ++--
 pilot/model/vicuna_llm.py           |  2 +-
 pilot/server/webserver.py           | 38 +++++++++++++++++++++--------
 pilot/vector_store/extract_tovec.py |  7 +++---
 4 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/pilot/app.py b/pilot/app.py
index 5456621f2..b3f8d6ab1 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    docsearch = knownledge_tovec_st("./datasets/plan.md")
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     for doc in docs:
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index eba2834ae..f17a17a00 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *
 
 class VicunaRequestLLM(LLM):
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 4bf0d0eb4..4a107e791 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,7 +12,7 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-
+from pilot.vector_store.extract_tovec import get_vector_storelist
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
 from pilot.conversation import (
     default_conversation,
@@ -42,6 +42,7 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()
 
 priority = {
     "vicuna-13b": "aaa"
@@ -255,7 +256,7 @@ def build_single_model_ui():
     The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
     """
-    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State(vs_list)
+    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
     state = gr.State()
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
@@ -279,14 +280,6 @@ def build_single_model_ui():
         )
     tabs = gr.Tabs()
     with tabs:
-        with gr.TabItem("知识问答", elem_id="QA"):
-            doc2vec = gr.Column(visible=False)
-            with doc2vec:
-                mode = gr.Radio(["默认知识库对话", "新增知识库"])
-                vs_setting = gr.Accordion("配置知识库")
-                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-                with vs_setting:
-                    select_vs = gr.Dropdown()
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -296,6 +289,31 @@ def build_single_model_ui():
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)
+
+        with gr.TabItem("知识问答", elem_id="QA"):
+
+            mode = gr.Radio(["默认知识库对话", "新增知识库"])
+            vs_setting = gr.Accordion("配置知识库")
+            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+            with vs_setting:
+                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                vs_add = gr.Button("添加为新知识库")
+            with gr.Column() as doc2vec:
+                gr.Markdown("向知识库中添加文件")
+                with gr.Tab("上传文件"):
+                    files = gr.File(label="添加文件",
+                                    file_types=[".txt", ".md", ".docx", ".pdf"],
+                                    file_count="multiple",
+                                    show_label=False
+                                    )
+
+                    load_file_button = gr.Button("上传并加载到知识库")
+                with gr.Tab("上传文件夹"):
+                    folder_files = gr.File(label="添加文件",
+                                           file_count="directory",
+                                           show_label=False)
+                    load_folder_button = gr.Button("上传并加载到知识库")
+
     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 2581b0264..e571ac54f 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -26,13 +26,14 @@ def knownledge_tovec_st(filename):
     """ Use sentence transformers to embed the document.
        https://github.com/UKPLab/sentence-transformers
     """
-    from pilot.configs.model_config import llm_model_config
-    embeddings = HuggingFaceEmbeddings(model=llm_model_config["sentence-transforms"])
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
 
     with open(filename, "r") as f:
         knownledge = f.read()
-
+
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+
     texts = text_splitter(knownledge)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
     return docsearch
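
Patch 4 builds out the knowledge-base tab and wires mode.change to change_mode, which returns gr.update(visible=...) aimed at the settings Accordion. A self-contained sketch of just that show/hide pattern, assuming the gradio 3.2x API pinned in requirements.txt:

    import gradio as gr

    def change_mode(mode):
        # Hide the knowledge-base settings in the default mode, show otherwise.
        return gr.update(visible=(mode != "默认知识库对话"))

    with gr.Blocks() as demo:
        mode = gr.Radio(["默认知识库对话", "新增知识库"], value="默认知识库对话", show_label=False)
        with gr.Accordion("配置知识库", open=False, visible=False) as vs_setting:
            vs_name = gr.Textbox(label="新知识库名称", lines=1)
        mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

    if __name__ == "__main__":
        demo.launch()
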
From d91a0e7d56025f376fe7e035e41f8e944d4841c1 Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 22:29:33 +0800
Subject: [PATCH 5/9] update requirements

---
 pilot/vector_store/extract_tovec.py | 2 +-
 requirements.txt                    | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index e571ac54f..ccfe7bda0 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -34,7 +34,7 @@ def knownledge_tovec_st(filename):
 
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
 
-    texts = text_splitter(knownledge)
+    texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
     return docsearch
 
diff --git a/requirements.txt b/requirements.txt
index 50354b3b4..1dd64f2d0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -50,4 +50,5 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
-fschat=0.1.10
\ No newline at end of file
+fschat=0.1.10
+llama-index=0.5.27
\ No newline at end of file

From 5ced79fc9af2964534bffeb368c4a27293ae7307 Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 22:43:22 +0800
Subject: [PATCH 6/9] update requirements

---
 requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 1dd64f2d0..f0ddf8fb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -51,4 +51,5 @@ gradio==3.24.1
 gradio-client==0.0.8
 wandb
 fschat=0.1.10
-llama-index=0.5.27
\ No newline at end of file
+llama-index=0.5.27
+pymysql
\ No newline at end of file

From 1e9934eb03f253ec5cb501b53d09aa6afe6b6a1e Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 22:44:21 +0800
Subject: [PATCH 7/9] update

---
 pilot/server/webserver.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 4a107e791..91b51c10d 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -292,8 +292,8 @@ def build_single_model_ui():
 
         with gr.TabItem("知识问答", elem_id="QA"):
 
-            mode = gr.Radio(["默认知识库对话", "新增知识库"])
-            vs_setting = gr.Accordion("配置知识库")
+            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+            vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
                 vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
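
Patches 5 through 7 are small follow-ups: the split_text() fix, two new dependencies (llama-index, pymysql), and saner defaults for the QA controls. The split_text() change is the behavioral one; CharacterTextSplitter instances are not callable, so text_splitter(knownledge) raises TypeError. A quick check of what the supported entry point does:

    from langchain.text_splitter import CharacterTextSplitter

    # Small sizes just to make the chunking visible.
    splitter = CharacterTextSplitter(separator="\n\n", chunk_size=20, chunk_overlap=0)
    chunks = splitter.split_text("first paragraph\n\nsecond paragraph\n\nthird one")
    print(chunks)  # pieces split at blank lines, merged up to ~20 characters
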
From 529f07740907b22dd02016f03002db841bd99344 Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 23:42:51 +0800
Subject: [PATCH 8/9] load base knownledge

---
 pilot/configs/model_config.py       |  2 +-
 pilot/datasets/__init__.py          |  0
 pilot/server/webserver.py           | 35 ++++++++++++++++++++---------
 pilot/vector_store/extract_tovec.py | 31 +++++++++++++++++++++++--
 4 files changed, 54 insertions(+), 14 deletions(-)
 delete mode 100644 pilot/datasets/__init__.py

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 675c51b66..a9436276f 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
 LOGDIR = os.path.join(ROOT_PATH, "logs")
-
+DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 LLM_MODEL_CONFIG = {
diff --git a/pilot/datasets/__init__.py b/pilot/datasets/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 91b51c10d..8154ea99d 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,7 +12,7 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.vector_store.extract_tovec import get_vector_storelist
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
 
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
@@ -48,6 +48,16 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
+def get_simlar(q):
+
+    docsearch = load_knownledge_from_doc()
+    docs = docsearch.similarity_search_with_score(q, k=1)
+
+    contents = [dc.page_content for dc, _ in docs]
+    return "\n".join(contents)
+
+
+
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS
@@ -150,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    query = state.messages[-2][1]
 
     if len(state.messages) == state.offset + 2:
         # The first round of conversation needs the prompt added
@@ -158,11 +169,15 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.conv_id = uuid.uuid4().hex
 
         # Add context hints into the prompt
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
+        if db_selector:
+            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+
         new_state.append_message(new_state.roles[1], None)
         state = new_state
-
+    if not db_selector:
+        state.append_message(new_state.roles[0], get_simlar(query) + query)
+
     prompt = state.get_prompt()
 
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@@ -237,6 +252,9 @@ pre {
 """
 )
 
+def change_tab(tab):
+    pass
+
 def change_mode(mode):
     if mode == "默认知识库对话":
         return gr.update(visible=False)
@@ -256,7 +274,6 @@ def build_single_model_ui():
     The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
     """
-    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
     state = gr.State()
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
@@ -278,10 +295,10 @@ def build_single_model_ui():
             interactive=True,
             label="最大输出Token数",
         )
-    tabs = gr.Tabs()
+    tabs = gr.Tabs()
     with tabs:
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
-            # TODO A selector to choose database
+            # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
                     label="请选择数据库",
                     choices=dbs,
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)
-
-        with gr.TabItem("知识问答", elem_id="QA"):
+
+        with gr.TabItem("知识问答", elem_id="QA"):
             mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
                 vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
                 vs_add = gr.Button("添加为新知识库")
             with gr.Column() as doc2vec:
@@ -331,9 +347,6 @@ def build_single_model_ui():
             regenerate_btn = gr.Button(value="重新生成", interactive=False)
             clear_btn = gr.Button(value="清理", interactive=False)
 
-    # Clear the database selector option in QA mode
-    if tabs.elem_id == "QA":
-        db_selector = ""
 
     gr.Markdown(learn_more_markdown)
 
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index ccfe7bda0..223ff90c8 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -6,7 +6,7 @@ import os
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-from pilot.configs.model_config import VECTORE_PATH
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
 from langchain.embeddings import HuggingFaceEmbeddings
 
 embeddings = VicunaEmbeddingLLM()
@@ -14,7 +14,7 @@ embeddings = VicunaEmbeddingLLM()
 def knownledge_tovec(filename):
     with open(filename, "r") as f:
         knownledge = f.read()
-
+
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(
@@ -38,6 +38,33 @@ def knownledge_tovec_st(filename):
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
     return docsearch
 
+
+def load_knownledge_from_doc():
+    """Load knowledge from the datasets.
+    # TODO: no need to initialize the vector store if it already exists
+    """
+
+    if not os.path.exists(DATASETS_DIR):
+        print("Local datasets not found. The model will answer with its default knowledge.")
+
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    docs = []
+    files = os.listdir(DATASETS_DIR)
+    for file in files:
+        if not os.path.isdir(file):
+            with open(file, "r") as f:
+                doc = f.read()
+            docs.append(docs)
+
+    print(doc)
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text("\n".join(docs))
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
+                                  persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
+    return docsearch
+
 def get_vector_storelist():
     if not os.path.exists(VECTORE_PATH):
         return []
     return os.listdir(VECTORE_PATH)
\ No newline at end of file
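
Patch 8's loader has a few issues that patch 9 only partially addresses: it opens bare names returned by os.listdir instead of joined paths, appends the docs list to itself, and ends up splitting only the last file it read. A corrected sketch of the intended behavior, embedding every dataset file into one persisted Chroma store (paths and model name are illustrative):

    import os

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import Chroma

    DATASETS_DIR = "pilot/datasets"
    VECTORE_PATH = "vector_store"

    def load_knownledge_from_doc():
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        docs = []
        for name in os.listdir(DATASETS_DIR):
            path = os.path.join(DATASETS_DIR, name)  # join, don't open bare names
            if os.path.isfile(path):
                with open(path, "r") as f:
                    docs.append(f.read())  # append the text, not the list

        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = splitter.split_text("\n".join(docs))  # split all files, not just the last
        return Chroma.from_texts(
            texts,
            embeddings,
            metadatas=[{"source": str(i)} for i in range(len(texts))],
            persist_directory=os.path.join(VECTORE_PATH, ".vectore"),
        )
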
From c1758f030b50605a6e889577400c6000b01b36b3 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 6 May 2023 00:41:35 +0800
Subject: [PATCH 9/9] knownledge based qa

---
 pilot/app.py                        |  4 ++--
 pilot/configs/model_config.py       |  2 +-
 pilot/server/webserver.py           | 24 ++++++++++++++++--------
 pilot/vector_store/extract_tovec.py | 10 ++++------
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/pilot/app.py b/pilot/app.py
index b3f8d6ab1..1cbcd79b0 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
-    docsearch = knownledge_tovec_st("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
+    docsearch = load_knownledge_from_doc()
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     for doc in docs:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index a9436276f..4527fafe0 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
 LOGDIR = os.path.join(ROOT_PATH, "logs")
-DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
+DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 LLM_MODEL_CONFIG = {
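
The DATASETS_DIR fix above repoints the loader at pilot/datasets, where the dataset files actually live. Passing "pilot/datasets" as a single os.path.join component works on POSIX; joining the components separately keeps the path portable:

    import os

    ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    DATASETS_DIR = os.path.join(ROOT_PATH, "pilot", "datasets")  # portable form
    print(DATASETS_DIR)
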
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 8154ea99d..b7a0ef816 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,9 +12,9 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
 
 from pilot.conversation import (
     default_conversation,
@@ -50,7 +50,7 @@ priority = {
 
 def get_simlar(q):
 
-    docsearch = load_knownledge_from_doc()
+    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     contents = [dc.page_content for dc, _ in docs]
@@ -171,12 +171,20 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         # Add context hints into the prompt
         if db_selector:
             new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+        else:
+            new_state.append_message(new_state.roles[0], query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
 
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
-
-    if not db_selector:
-        state.append_message(new_state.roles[0], get_simlar(query) + query)
+    try:
+        if not db_selector:
+            sim_q = get_simlar(query)
+            print("********vector similar info*************: ", sim_q)
+            state.append_message(new_state.roles[0], sim_q + query)
+    except Exception as e:
+        print(e)
 
     prompt = state.get_prompt()
 
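
The http_bot change reduces to: when no database is selected, retrieve the most similar knowledge chunk and prepend it to the user query, falling back to the raw query if retrieval throws. A condensed sketch of that control flow, where docsearch stands for the store built by knownledge_tovec_st:

    def augment_query(query, docsearch):
        # Mirror the try/except QA branch: prepend retrieved context, or
        # degrade gracefully to the bare query on any retrieval failure.
        try:
            docs = docsearch.similarity_search_with_score(query, k=1)
            context = "\n".join(dc.page_content for dc, _ in docs)
            print("********vector similar info*************: ", context)
            return context + query
        except Exception as e:
            print(e)
            return query
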
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 223ff90c8..8badf6fed 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -50,17 +50,15 @@ def load_knownledge_from_doc():
     from pilot.configs.model_config import LLM_MODEL_CONFIG
     embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
 
-    docs = []
     files = os.listdir(DATASETS_DIR)
     for file in files:
         if not os.path.isdir(file):
-            with open(file, "r") as f:
-                doc = f.read()
-            docs.append(docs)
+            filename = os.path.join(DATASETS_DIR, file)
+            with open(filename, "r") as f:
+                knownledge = f.read()
 
-    print(doc)
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    texts = text_splitter.split_text("\n".join(docs))
+    texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
                                   persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
     return docsearch
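
The TODO left in load_knownledge_from_doc (skip re-embedding when the store already exists) maps onto Chroma's persistence API: call persist() after the first build, then reopen the directory with the same embedding function. A sketch under that assumption, using the langchain-era Chroma constructor:

    import os

    from langchain.vectorstores import Chroma

    def open_or_build(texts, embeddings, persist_dir):
        if os.path.isdir(persist_dir):
            # Reuse the persisted collection instead of re-embedding everything.
            return Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        store = Chroma.from_texts(texts, embeddings, persist_directory=persist_dir)
        store.persist()
        return store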