diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 4468082d8..4bf0d0eb4 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -236,6 +236,12 @@ pre { """ ) +def change_mode(mode): + if mode == "默认知识库对话": + return gr.update(visible=False) + else: + return gr.update(visible=True) + def build_single_model_ui(): @@ -249,6 +255,7 @@ def build_single_model_ui(): The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA """ + vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State(vs_list) state = gr.State() gr.Markdown(notice_markdown, elem_id="notice_markdown") @@ -270,10 +277,16 @@ def build_single_model_ui(): interactive=True, label="最大输出Token数", ) - - with gr.Tabs(): + tabs = gr.Tabs() + with tabs: with gr.TabItem("知识问答", elem_id="QA"): - pass + doc2vec = gr.Column(visible=False) + with doc2vec: + mode = gr.Radio(["默认知识库对话", "新增知识库"]) + vs_setting = gr.Accordion("配置知识库") + mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) + with vs_setting: + select_vs = gr.Dropdown() with gr.TabItem("SQL生成与诊断", elem_id="SQL"): # TODO A selector to choose database with gr.Row(elem_id="db_selector"): @@ -300,6 +313,10 @@ def build_single_model_ui(): regenerate_btn = gr.Button(value="重新生成", interactive=False) clear_btn = gr.Button(value="清理", interactive=False) + # QA 模式下清空数据库选项 + if tabs.elem_id == "QA": + db_selector = "" + gr.Markdown(learn_more_markdown) btn_list = [regenerate_btn, clear_btn] diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py index 5b7df3eb2..2581b0264 100644 --- a/pilot/vector_store/extract_tovec.py +++ b/pilot/vector_store/extract_tovec.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- +import os from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import Chroma from pilot.model.vicuna_llm import VicunaEmbeddingLLM -# from langchain.embeddings import SentenceTransformerEmbeddings - +from pilot.configs.model_config import VECTORE_PATH +from langchain.embeddings import HuggingFaceEmbeddings embeddings = VicunaEmbeddingLLM() @@ -21,18 +22,22 @@ def knownledge_tovec(filename): ) return docsearch +def knownledge_tovec_st(filename): + """ Use sentence transformers to embedding the document. + https://github.com/UKPLab/sentence-transformers + """ + from pilot.configs.model_config import llm_model_config + embeddings = HuggingFaceEmbeddings(model=llm_model_config["sentence-transforms"]) -# def knownledge_tovec_st(filename): -# """ Use sentence transformers to embedding the document. -# https://github.com/UKPLab/sentence-transformers -# """ -# from pilot.configs.model_config import llm_model_config -# embeddings = SentenceTransformerEmbeddings(model=llm_model_config["sentence-transforms"]) - -# with open(filename, "r") as f: -# knownledge = f.read() + with open(filename, "r") as f: + knownledge = f.read() -# text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) -# texts = text_splitter(knownledge) -# docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]) -# return docsearch + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + texts = text_splitter(knownledge) + docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]) + return docsearch + +def get_vector_storelist(): + if not os.path.exists(VECTORE_PATH): + return [] + return os.listdir(VECTORE_PATH) \ No newline at end of file