feature:web server knowledge embedding

This commit is contained in:
chenketing
2023-05-14 22:12:20 +08:00
parent 88d23c2147
commit 8ee288a8fd
21 changed files with 42 additions and 5 deletions

View File

@@ -40,4 +40,8 @@ DB_SETTINGS = {
"password": "aa123456", "password": "aa123456",
"host": "127.0.0.1", "host": "127.0.0.1",
"port": 3306 "port": 3306
} }
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
MODEL_NAME_PATH = "model/all-MiniLM-L6-v2"

View File

@@ -3,6 +3,7 @@
import argparse import argparse
import os import os
import shutil
import uuid import uuid
import json import json
import time import time
@@ -10,9 +11,10 @@ import gradio as gr
import datetime import datetime
import requests import requests
from urllib.parse import urljoin from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, MODEL_NAME_PATH, VS_ROOT_PATH
from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
@@ -481,7 +483,15 @@ def build_single_model_ui():
[state, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list [state, chatbot] + btn_list
) )
load_file_button.click(fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name])
# load_folder_button.click(get_vector_store,
# show_progress=True,
# inputs=[vs_name, folder_files, 100 , chatbot, vs_add,
# vs_add],
# outputs=["db-out", folder_files, chatbot])
return state, chatbot, textbox, send_btn, button_row, parameter_row return state, chatbot, textbox, send_btn, button_row, parameter_row
@@ -520,6 +530,23 @@ def build_webdemo():
raise ValueError(f"Unknown model list mode: {args.model_list_mode}") raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo return demo
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
for file in files:
filename = os.path.split(file.name)[-1]
shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
knowledge_embedding = PDFEmbedding(file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), model_name=MODEL_NAME_PATH,
vector_store_config={"vector_store_name": vs_id,
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
knowledge_embedding.source_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename + ".vectordb")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")

View File

@@ -54,8 +54,14 @@ class SourceEmbedding(ABC):
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb") self.vector_store_config["vector_store_name"] + ".vectordb")
vector_store = Chroma.from_documents(docs, embeddings, persist_directory=persist_dir) self.vector_store = Chroma.from_documents(docs, embeddings, persist_directory=persist_dir)
vector_store.persist() self.vector_store.persist()
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_store.similarity_search(doc, topk)
def source_embedding(self): def source_embedding(self):
if 'read' in registered_methods: if 'read' in registered_methods: