mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-11 05:49:22 +00:00
update:merge
@@ -6,6 +6,7 @@ import os
import shutil
import uuid
import json
import sys
import time
import gradio as gr
import datetime
@@ -14,13 +15,17 @@ from urllib.parse import urljoin

from langchain import PromptTemplate

from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st

from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
from pilot.configs.model_config import LOGDIR, DATASETS_DIR

from pilot.plugins import scan_plugins
from pilot.configs.config import Config
@@ -30,6 +35,8 @@ from pilot.prompts.generator import PromptGenerator

from pilot.commands.exception_not_commands import NotCommands



from pilot.conversation import (
    default_conversation,
    conv_templates,
@@ -67,7 +74,15 @@ priority = {
    "vicuna-13b": "aaa"
}

# Load plugins
CFG= Config()

DB_SETTINGS = {
    "user": CFG.LOCAL_DB_USER,
    "password": CFG.LOCAL_DB_PASSWORD,
    "host": CFG.LOCAL_DB_HOST,
    "port": CFG.LOCAL_DB_PORT
}
def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
    docs = docsearch.similarity_search_with_score(q, k=1)
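This hunk replaces the DB_SETTINGS previously imported from pilot.configs.model_config with a dict built from the runtime Config object. A minimal sketch of how a DB_SETTINGS-style dict might be consumed, assuming a pymysql-style client (MySQLOperator's actual interface is not shown in this diff):

```python
# Sketch only: illustrates consuming a DB_SETTINGS-style dict built from Config.
# pymysql is an assumption here; the diff only shows MySQLOperator being imported.
import pymysql

def open_connection(settings: dict):
    """Open a MySQL connection from a DB_SETTINGS-style dict."""
    return pymysql.connect(
        host=settings["host"],
        port=int(settings["port"]),
        user=settings["user"],
        password=settings["password"],
        charset="utf8mb4",
    )

# Usage (hypothetical):
# conn = open_connection(DB_SETTINGS)
# with conn.cursor() as cur:
#     cur.execute("SHOW DATABASES")
#     print(cur.fetchall())
```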
@@ -178,7 +193,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print("Is this AUTO-GPT mode?", autogpt)

start_tstamp = time.time()
model_name = LLM_MODEL
model_name = CFG.LLM_MODEL

dbname = db_selector
# TODO: the request here needs to splice the existing knowledge base into the prompt so the answer is grounded in it, so the prompt still needs further optimization
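The TODO above asks for the existing knowledge base to be spliced into the request. A minimal sketch of one way to do that with the PromptTemplate already imported at the top of the file; the template wording and the build_knowledge_prompt helper are hypothetical, not part of this commit:

```python
# Sketch only: grounding the prompt in retrieved knowledge-base chunks.
# The template text and the helper name are assumptions, not repository code.
from langchain import PromptTemplate

_KNOWLEDGE_TEMPLATE = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Answer the question using only the context below.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n"
        "Answer:"
    ),
)

def build_knowledge_prompt(docsearch, question: str, k: int = 1) -> str:
    """Retrieve the top-k similar chunks and splice them into the prompt."""
    docs = docsearch.similarity_search_with_score(question, k=k)
    context = "\n".join(doc.page_content for doc, _score in docs)
    return _KNOWLEDGE_TEMPLATE.format(context=context, question=question)
```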
@@ -282,7 +297,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
logger.info(f"Requert: \n{payload}")

if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers, json=payload, timeout=120)

print(response.json())
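In auto_execute_ai_response mode the request now targets CFG.MODEL_SERVER instead of the hard-coded VICUNA_MODEL_SERVER. A minimal sketch of such a non-streaming call; the payload is assembled earlier in http_bot and is not shown in this hunk, so the helper below treats it as an opaque dict:

```python
# Sketch only: a single (non-streaming) generation request to the model server.
# The header value and payload contents are assumptions; only the endpoint name
# and the 120-second timeout appear in the diff above.
import requests
from urllib.parse import urljoin

def generate_once(model_server: str, payload: dict, timeout: int = 120) -> dict:
    """POST one generation request and return the parsed JSON response."""
    headers = {"User-Agent": "dbgpt-client"}
    resp = requests.post(
        urljoin(model_server, "generate"),
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()
```

Note that urljoin replaces the last path segment of a base URL that lacks a trailing slash, so this pattern assumes MODEL_SERVER is a bare scheme://host:port base.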
@@ -330,7 +345,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re

try:
# Stream output
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
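The streaming branch likewise switches to CFG.MODEL_SERVER and reads null-delimited chunks. A minimal sketch of how such a stream might be consumed; the per-chunk JSON schema ("text", "error_code") is an assumption, since the body of the loop is truncated in this hunk:

```python
# Sketch only: consuming a b"\0"-delimited streaming response like the one above.
# The chunk schema (JSON objects with "text" and "error_code") is assumed.
import json
import requests

def iter_stream_text(response: requests.Response):
    """Yield generated text from each null-delimited JSON chunk."""
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        if data.get("error_code", 0) != 0:
            raise RuntimeError(f"model server error: {data}")
        yield data.get("text", "")
```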
@@ -606,12 +621,11 @@ if __name__ == "__main__":

args = parser.parse_args()
logger.info(f"args: {args}")

# dbs = get_database_list()

# Load plugins
# Initialize configuration
cfg = Config()

dbs = get_database_list()

cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

# Load plugin executable commands
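At startup the Config instance now collects the scanned plugins via cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)). A minimal sketch of how those plugins might later contribute executable commands to a PromptGenerator, following the Auto-GPT-style plugin interface; the can_handle_post_prompt and post_prompt method names are assumptions here:

```python
# Sketch only: wiring scanned plugins into a PromptGenerator.
# The plugin method names follow the Auto-GPT plugin convention and are
# assumptions; only scan_plugins and cfg.set_plugins appear in this diff.
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator

def register_plugin_commands(cfg: Config, prompt_generator: PromptGenerator) -> PromptGenerator:
    """Let each loaded plugin append its executable commands to the prompt."""
    for plugin in cfg.plugins:
        if not plugin.can_handle_post_prompt():
            continue
        prompt_generator = plugin.post_prompt(prompt_generator)
    return prompt_generator
```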