feature:db_summary

This commit is contained in:
aries-ckt
2023-06-01 14:02:23 +08:00
parent 3a46dfd3c2
commit b10269a550
30 changed files with 728 additions and 57 deletions

View File

@@ -4,7 +4,7 @@
from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
@@ -53,3 +53,17 @@ class KnownLedgeBaseQA:
print("new prompt length:" + str(len(prompt)))
return prompt
@staticmethod
def build_db_summary_prompt(query, db_profile_summary, state):
prompt_template = PromptTemplate(
template=conv_db_summary_templates,
input_variables=["db_input", "db_profile_summary"],
)
# context = [d.page_content for d in docs]
result = prompt_template.format(
db_profile_summary=db_profile_summary, db_input=query
)
state.messages[-2][1] = result
prompt = state.get_prompt()
return prompt

View File

@@ -3,16 +3,14 @@
import traceback
import argparse
import datetime
import json
import os
import shutil
import sys
import time
import uuid
import gradio as gr
import requests
from pilot.summary.db_summary_client import DBSummaryClient
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -27,13 +25,9 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_sql_mode,
conversation_types,
chat_mode_title,
@@ -41,19 +35,15 @@ from pilot.conversation import (
)
from pilot.common.plugins import scan_plugins
from pilot.prompts.generator import PluginPromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
@@ -75,6 +65,7 @@ vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
# db_summary = {"dbsummary": ""}
priority = {"vicuna-13b": "aaa"}
@@ -333,7 +324,7 @@ def http_bot(
state.messages[-1][-1] = "Error:" + str(e)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
if state.messages[-1][-1].endwith(""):
if state.messages[-1][-1].endswith(""):
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@@ -420,6 +411,8 @@ def build_single_model_ui():
name="db_selector"
).style(container=False)
db_selector.change(fn=db_selector_changed, inputs=db_selector)
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
@@ -615,6 +608,10 @@ def save_vs_name(vs_name):
return vs_name
def db_selector_changed(dbname):
DBSummaryClient.db_summary_embedding(dbname)
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):