diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 675c51b66..a9436276f 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
 LOGDIR = os.path.join(ROOT_PATH, "logs")
-
+DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 LLM_MODEL_CONFIG = {
diff --git a/pilot/datasets/__init__.py b/pilot/datasets/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 91b51c10d..8154ea99d 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,7 +12,7 @@ import requests
 from urllib.parse import urljoin
 
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.vector_store.extract_tovec import get_vector_storelist
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
 
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
@@ -48,6 +48,16 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
+def get_simlar(q):
+    docsearch = load_knownledge_from_doc()
+    docs = docsearch.similarity_search_with_score(q, k=1)
+
+    contents = [dc.page_content for dc, _ in docs]
+    return "\n".join(contents)
+
+
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS
@@ -150,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    query = state.messages[-2][1]
     if len(state.messages) == state.offset + 2:
         # The first round of conversation needs to include the prompt
@@ -158,11 +169,15 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.conv_id = uuid.uuid4().hex
 
         # Add context hints to the prompt
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
+        if db_selector:
+            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+
         new_state.append_message(new_state.roles[1], None)
         state = new_state
-
+    if not db_selector:
+        state.append_message(new_state.roles[0], get_simlar(query) + query)
+
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@@ -237,6 +252,9 @@ pre {
 """
 )
 
+def change_tab(tab):
+    pass
+
 def change_mode(mode):
     if mode == "默认知识库对话":
         return gr.update(visible=False)
@@ -256,7 +274,6 @@ def build_single_model_ui():
    The service is a research preview intended for non-commercial use only.
    subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
    """
-    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
     state = gr.State()
 
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
@@ -278,10 +295,10 @@ def build_single_model_ui():
             interactive=True,
             label="最大输出Token数",
         )
 
-    tabs = gr.Tabs()
+    tabs = gr.Tabs()
     with tabs:
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
-            # TODO A selector to choose database
+            # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
                     label="请选择数据库",
                     choices=dbs,
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)
-
-        with gr.TabItem("知识问答", elem_id="QA"):
+        with gr.TabItem("知识问答", elem_id="QA"):
             mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
@@ -331,9 +347,6 @@ def build_single_model_ui():
                 regenerate_btn = gr.Button(value="重新生成", interactive=False)
                 clear_btn = gr.Button(value="清理", interactive=False)
 
-        # Clear the database selection when in QA mode
-        if tabs.elem_id == "QA":
-            db_selector = ""
 
     gr.Markdown(learn_more_markdown)
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index ccfe7bda0..223ff90c8 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -6,7 +6,7 @@ import os
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-from pilot.configs.model_config import VECTORE_PATH
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
 from langchain.embeddings import HuggingFaceEmbeddings
 
 embeddings = VicunaEmbeddingLLM()
@@ -14,7 +14,7 @@ embeddings = VicunaEmbeddingLLM()
 def knownledge_tovec(filename):
     with open(filename, "r") as f:
         knownledge = f.read()
-    
+
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(
@@ -38,6 +38,33 @@ def knownledge_tovec_st(filename):
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
     return docsearch
 
+
+def load_knownledge_from_doc():
+    """Load knowledge from the local datasets directory.
+
+    # TODO: skip initialization if the vector store already exists
+    """
+    if not os.path.exists(DATASETS_DIR):
+        print("No local datasets found; questions will be answered from the model's default knowledge.")
+        return None
+
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    docs = []
+    for file in os.listdir(DATASETS_DIR):
+        path = os.path.join(DATASETS_DIR, file)
+        if not os.path.isdir(path):
+            with open(path, "r") as f:
+                docs.append(f.read())
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text("\n".join(docs))
+    docsearch = Chroma.from_texts(texts, embeddings,
+                                  metadatas=[{"source": str(i)} for i in range(len(texts))],
+                                  persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
+    return docsearch
+
 def get_vector_storelist():
     if not os.path.exists(VECTORE_PATH):
         return []
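
Reviewer note: below is a minimal, self-contained sketch of the retrieval path this change set wires together: build a Chroma store from the files under a datasets directory, then prepend the most similar chunk to the user's question, as load_knownledge_from_doc and get_simlar do above. The directory paths and the sentence-transformers model name here are illustrative stand-ins, not the values the PR reads from its config.

    import os

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import Chroma

    DATASETS_DIR = "./datasets"              # stand-in for the PR's DATASETS_DIR
    PERSIST_DIR = "./vector_store/.vectore"  # stand-in for the PR's persist location

    def build_store():
        # Concatenate every regular file under DATASETS_DIR into one corpus.
        docs = []
        for name in os.listdir(DATASETS_DIR):
            path = os.path.join(DATASETS_DIR, name)
            if os.path.isfile(path):
                with open(path, "r") as f:
                    docs.append(f.read())

        # Split into ~1000-character chunks, as the PR does.
        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = splitter.split_text("\n".join(docs))

        # Illustrative embedding model; the PR resolves this from LLM_MODEL_CONFIG.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        return Chroma.from_texts(
            texts,
            embeddings,
            metadatas=[{"source": str(i)} for i in range(len(texts))],
            persist_directory=PERSIST_DIR,
        )

    def get_similar_context(store, question, k=1):
        # similarity_search_with_score returns (Document, score) pairs;
        # keep only the page content, joined with newlines.
        pairs = store.similarity_search_with_score(question, k=k)
        return "\n".join(doc.page_content for doc, _ in pairs)

    if __name__ == "__main__":
        store = build_store()
        question = "How do I connect to the database?"
        print(get_similar_context(store, question) + question)

One behavioral consequence worth noting in review: http_bot now branches solely on db_selector. With the old tabs.elem_id check removed, a database selected in the SQL tab still forces the SQL-generation context even when the user asks from the QA tab, and an empty selector always routes the question to the knowledge-base lookup.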