From 7da910b6428f205e0ab86fbb73df7b4f44ec9ccb Mon Sep 17 00:00:00 2001
From: csunny
Date: Fri, 5 May 2023 22:14:28 +0800
Subject: [PATCH] add gradio tem

---
 pilot/app.py                        |  4 +--
 pilot/model/vicuna_llm.py           |  2 +-
 pilot/server/webserver.py           | 38 +++++++++++++++++++++--------
 pilot/vector_store/extract_tovec.py |  7 +++---
 4 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/pilot/app.py b/pilot/app.py
index 5456621f2..b3f8d6ab1 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    docsearch = knownledge_tovec_st("./datasets/plan.md")
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     for doc in docs:
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index eba2834ae..f17a17a00 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *
 
 
 class VicunaRequestLLM(LLM):
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 4bf0d0eb4..4a107e791 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,7 +12,7 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-
+from pilot.vector_store.extract_tovec import get_vector_storelist
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
 
@@ -42,6 +42,7 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()
 
 priority = {
     "vicuna-13b": "aaa"
@@ -255,7 +256,7 @@
 The service is a research preview intended for non-commercial use only.
 subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
 """
-    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State(vs_list)
+    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
     state = gr.State()
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
@@ -279,14 +280,6 @@
         )
     tabs = gr.Tabs()
     with tabs:
-        with gr.TabItem("知识问答", elem_id="QA"):
-            doc2vec = gr.Column(visible=False)
-            with doc2vec:
-                mode = gr.Radio(["默认知识库对话", "新增知识库"])
-                vs_setting = gr.Accordion("配置知识库")
-                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-                with vs_setting:
-                    select_vs = gr.Dropdown()
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -296,6 +289,31 @@
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)
+
+        with gr.TabItem("知识问答", elem_id="QA"):
+
+            mode = gr.Radio(["默认知识库对话", "新增知识库"])
+            vs_setting = gr.Accordion("配置知识库")
+            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+            with vs_setting:
+                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                vs_add = gr.Button("添加为新知识库")
+            with gr.Column() as doc2vec:
+                gr.Markdown("向知识库中添加文件")
+                with gr.Tab("上传文件"):
+                    files = gr.File(label="添加文件",
+                                    file_types=[".txt", ".md", ".docx", ".pdf"],
+                                    file_count="multiple",
+                                    show_label=False
+                                    )
+
+                    load_file_button = gr.Button("上传并加载到知识库")
+                with gr.Tab("上传文件夹"):
+                    folder_files = gr.File(label="添加文件",
+                                           file_count="directory",
+                                           show_label=False)
+                    load_folder_button = gr.Button("上传并加载到知识库")
+
 
     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 2581b0264..e571ac54f 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -26,13 +26,14 @@ def knownledge_tovec_st(filename):
     """ Use sentence transformers to embedding the document.
         https://github.com/UKPLab/sentence-transformers
     """
-    from pilot.configs.model_config import llm_model_config
-    embeddings = HuggingFaceEmbeddings(model=llm_model_config["sentence-transforms"])
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
 
     with open(filename, "r") as f:
         knownledge = f.read()
-    
+
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
     return docsearch
 
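
For context, the knowledge-base path this patch wires up (knownledge_tovec_st in pilot/vector_store/extract_tovec.py, called from get_similar in pilot/app.py) reads a file, splits it into chunks with CharacterTextSplitter, embeds the chunks with a sentence-transformers model via HuggingFaceEmbeddings, indexes them in Chroma, and answers queries with similarity_search_with_score. The following is a minimal standalone sketch of that flow, assuming the same langchain calls the diff uses; it is not part of the patch, and the helper name build_docsearch, the default model name, and the sample query are illustrative placeholders.

# Illustrative sketch only: mirrors the retrieval flow of knownledge_tovec_st / get_similar.
# Assumes the same langchain APIs the diff uses; build_docsearch, the default model name,
# and the sample query are placeholders, not identifiers from the repository.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma


def build_docsearch(filename, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Read a document, chunk it, embed the chunks, and index them in Chroma."""
    embeddings = HuggingFaceEmbeddings(model_name=model_name)

    with open(filename, "r") as f:
        knowledge = f.read()

    # split_text returns the list of chunks that Chroma will embed and index
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(knowledge)

    return Chroma.from_texts(
        texts,
        embeddings,
        metadatas=[{"source": str(i)} for i in range(len(texts))],
    )


if __name__ == "__main__":
    docsearch = build_docsearch("./datasets/plan.md")
    # similarity_search_with_score returns (Document, score) pairs;
    # in Chroma a lower score means a closer match.
    for doc, score in docsearch.similarity_search_with_score("What is the plan?", k=1):
        print(score, doc.page_content[:200])

The chunk_size=1000 / chunk_overlap=0 splitter settings match what the patch configures for knownledge_tovec_st, so the sketch indexes documents the same way the webserver-side knowledge base does.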