add gradio tem

csunny 2023-05-05 22:14:28 +08:00
parent 205eab7268
commit 7da910b642
4 changed files with 35 additions and 16 deletions

View File

@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response

 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    docsearch = knownledge_tovec_st("./datasets/plan.md")
     docs = docsearch.similarity_search_with_score(q, k=1)

     for doc in docs:
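Note: knownledge_tovec_st, the helper this hunk switches to, is defined in the last file of this commit and returns a Chroma store built with sentence-transformers embeddings. A minimal usage sketch of the call path above (the query text is illustrative only, not from the repo):

    from pilot.vector_store.extract_tovec import knownledge_tovec_st

    # Build a Chroma index over the markdown file and run a scored similarity search;
    # similarity_search_with_score yields (Document, score) pairs.
    docsearch = knownledge_tovec_st("./datasets/plan.md")
    for doc, score in docsearch.similarity_search_with_score("What is the plan?", k=1):
        print(score, doc.page_content[:80])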

View File

@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *

 class VicunaRequestLLM(LLM):

View File

@@ -12,7 +12,7 @@ import requests
 from urllib.parse import urljoin

 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
+from pilot.vector_store.extract_tovec import get_vector_storelist
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
@@ -42,6 +42,7 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()

 priority = {
     "vicuna-13b": "aaa"
@@ -255,7 +256,7 @@ def build_single_model_ui():
 The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
 """

-    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State(vs_list)
+    vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
     state = gr.State()

     gr.Markdown(notice_markdown, elem_id="notice_markdown")
@@ -279,14 +280,6 @@ def build_single_model_ui():
        )

     tabs = gr.Tabs()
     with tabs:
-        with gr.TabItem("知识问答", elem_id="QA"):
-            doc2vec = gr.Column(visible=False)
-            with doc2vec:
-                mode = gr.Radio(["默认知识库对话", "新增知识库"])
-                vs_setting = gr.Accordion("配置知识库")
-                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-                with vs_setting:
-                    select_vs = gr.Dropdown()
         with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -296,6 +289,31 @@ def build_single_model_ui():
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True).style(container=False)

+        with gr.TabItem("知识问答", elem_id="QA"):
+            mode = gr.Radio(["默认知识库对话", "新增知识库"])
+            vs_setting = gr.Accordion("配置知识库")
+            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+            with vs_setting:
+                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                vs_add = gr.Button("添加为新知识库")
+                with gr.Column() as doc2vec:
+                    gr.Markdown("向知识库中添加文件")
+                    with gr.Tab("上传文件"):
+                        files = gr.File(label="添加文件",
+                                        file_types=[".txt", ".md", ".docx", ".pdf"],
+                                        file_count="multiple",
+                                        show_label=False
+                                        )
+                        load_file_button = gr.Button("上传并加载到知识库")
+                    with gr.Tab("上传文件夹"):
+                        folder_files = gr.File(label="添加文件",
+                                               file_count="directory",
+                                               show_label=False)
+                        load_folder_button = gr.Button("上传并加载到知识库")

     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
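Note: the re-added 知识问答 (knowledge Q&A) tab wires mode.change to a change_mode callback that is not shown in this diff, and the upload buttons have no click handlers in the visible hunks. A plausible sketch of change_mode, assuming it only toggles the 配置知识库 (configure knowledge base) accordion when "新增知识库" (add a new knowledge base) is selected:

    import gradio as gr

    # Assumed implementation, not taken from this commit: show the knowledge-base
    # settings only when the radio is set to "新增知识库".
    def change_mode(mode):
        return gr.update(visible=(mode == "新增知识库"))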

View File

@@ -26,13 +26,14 @@ def knownledge_tovec_st(filename):
     """ Use sentence transformers to embedding the document.
     https://github.com/UKPLab/sentence-transformers
     """
-    from pilot.configs.model_config import llm_model_config
-    embeddings = HuggingFaceEmbeddings(model=llm_model_config["sentence-transforms"])
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])

     with open(filename, "r") as f:
         knownledge = f.read()

     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter(knownledge)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])

     return docsearch
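Note on the unchanged context lines: calling the splitter instance directly, texts = text_splitter(knownledge), is not the documented CharacterTextSplitter API in langchain; the usual call is split_text. A sketch of how the function reads with that call spelled out, everything else mirroring the diff:

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import Chroma

    from pilot.configs.model_config import LLM_MODEL_CONFIG

    def knownledge_tovec_st(filename):
        """Embed a document with sentence-transformers and index it in Chroma."""
        embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
        with open(filename, "r") as f:
            knownledge = f.read()
        # Split into ~1000-character chunks before embedding.
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text(knownledge)
        return Chroma.from_texts(
            texts, embeddings,
            metadatas=[{"source": str(i)} for i in range(len(texts))],
        )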