mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 06:30:02 +00:00
feature:db_summary
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from pilot.conversation import conv_qa_prompt_template
|
||||
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
|
||||
from pilot.logs import logger
|
||||
from pilot.model.vicuna_llm import VicunaLLM
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
@@ -53,3 +53,17 @@ class KnownLedgeBaseQA:
|
||||
print("new prompt length:" + str(len(prompt)))
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def build_db_summary_prompt(query, db_profile_summary, state):
|
||||
prompt_template = PromptTemplate(
|
||||
template=conv_db_summary_templates,
|
||||
input_variables=["db_input", "db_profile_summary"],
|
||||
)
|
||||
# context = [d.page_content for d in docs]
|
||||
result = prompt_template.format(
|
||||
db_profile_summary=db_profile_summary, db_input=query
|
||||
)
|
||||
state.messages[-2][1] = result
|
||||
prompt = state.get_prompt()
|
||||
return prompt
|
||||
|
@@ -3,16 +3,14 @@
|
||||
import traceback
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
|
||||
from pilot.summary.db_summary_client import DBSummaryClient
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
@@ -27,13 +25,9 @@ from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.conversation import (
|
||||
SeparatorStyle,
|
||||
conv_qa_prompt_template,
|
||||
conv_templates,
|
||||
conversation_sql_mode,
|
||||
conversation_types,
|
||||
chat_mode_title,
|
||||
@@ -41,19 +35,15 @@ from pilot.conversation import (
|
||||
)
|
||||
from pilot.common.plugins import scan_plugins
|
||||
|
||||
from pilot.prompts.generator import PluginPromptGenerator
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.utils import build_logger, server_error_msg
|
||||
from pilot.utils import build_logger
|
||||
from pilot.vector_store.extract_tovec import (
|
||||
get_vector_storelist,
|
||||
knownledge_tovec_st,
|
||||
load_knownledge_from_doc,
|
||||
)
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.language.translation_handler import get_lang_text
|
||||
@@ -75,6 +65,7 @@ vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
|
||||
autogpt = False
|
||||
vector_store_client = None
|
||||
vector_store_name = {"vs_name": ""}
|
||||
# db_summary = {"dbsummary": ""}
|
||||
|
||||
priority = {"vicuna-13b": "aaa"}
|
||||
|
||||
@@ -333,7 +324,7 @@ def http_bot(
|
||||
state.messages[-1][-1] = "Error:" + str(e)
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
|
||||
if state.messages[-1][-1].endwith("▌"):
|
||||
if state.messages[-1][-1].endswith("▌"):
|
||||
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
|
||||
@@ -420,6 +411,8 @@ def build_single_model_ui():
|
||||
name="db_selector"
|
||||
).style(container=False)
|
||||
|
||||
db_selector.change(fn=db_selector_changed, inputs=db_selector)
|
||||
|
||||
sql_mode = gr.Radio(
|
||||
[
|
||||
get_lang_text("sql_generate_mode_direct"),
|
||||
@@ -615,6 +608,10 @@ def save_vs_name(vs_name):
|
||||
return vs_name
|
||||
|
||||
|
||||
def db_selector_changed(dbname):
|
||||
DBSummaryClient.db_summary_embedding(dbname)
|
||||
|
||||
|
||||
def knowledge_embedding_store(vs_id, files):
|
||||
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
|
||||
|
Reference in New Issue
Block a user