Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-16 14:40:56 +00:00
feat(editor): DB GPT Editor Api

1. Add the DB-GPT Editor API
2. Remove the old web server file
@@ -131,7 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
     # Generic plugins
     plugins_path_path = Path(PLUGINS_DIR)
 
-
     for plugin in plugins_path_path.glob("*.zip"):
         if moduleList := inspect_zip_for_modules(str(plugin), debug):
             for module in moduleList:
@@ -11,7 +11,7 @@ class DBConfig(BaseModel):
     db_pwd: str = ""
     comment: str = ""
 
-class DbTypeInfo(BaseModel):
-    db_type:str
-    is_file_db: bool = False
 
+class DbTypeInfo(BaseModel):
+    db_type: str
+    is_file_db: bool = False
@@ -47,7 +47,8 @@ class DuckdbConnectConfig:
         except Exception as e:
             print("add db connect info error1!" + str(e))
 
-    def update_db_info(self,
+    def update_db_info(
+        self,
         db_name,
         db_type,
         db_path: str = "",
@@ -55,15 +56,20 @@ class DuckdbConnectConfig:
         db_port: int = 0,
         db_user: str = "",
         db_pwd: str = "",
-        comment: str = "" ):
+        comment: str = "",
+    ):
         old_db_conf = self.get_db_config(db_name)
         if old_db_conf:
             try:
                 cursor = self.connect.cursor()
                 if not db_path:
-                    cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'")
+                    cursor.execute(
+                        f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
+                    )
                 else:
-                    cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'")
+                    cursor.execute(
+                        f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
+                    )
                 cursor.commit()
                 self.connect.commit()
             except Exception as e:
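A reviewer note on the hunk above: the reformatting only wraps the long lines; the committed code still interpolates user-supplied values directly into the SQL string. A minimal sketch of the same update using parameter placeholders is shown below. It is hypothetical, not part of this commit, and assumes self.connect is a DuckDB connection as the rest of this class suggests.

                    # Hypothetical parameterized variant of the non-file-db branch (not in this commit).
                    cursor.execute(
                        "UPDATE connect_config SET db_type=?, db_host=?, db_port=?, db_user=?, db_pwd=?, comment=? WHERE db_name=?",
                        [db_type, db_host, db_port, db_user, db_pwd, comment, db_name],
                    )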
@@ -79,7 +85,6 @@ class DuckdbConnectConfig:
         except Exception as e:
             raise "Unusable duckdb database path:" + path
 
-
     def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
         try:
             cursor = self.connect.cursor()
@@ -122,14 +122,16 @@ class ConnectManager:
         return self.storage.delete_db(db_name)
 
     def edit_db(self, db_info: DBConfig):
-        return self.storage.update_db_info(db_info.db_name,
-            db_info.db_type,
-            db_info.file_path,
-            db_info.db_host,
-            db_info.db_port,
-            db_info.db_user,
-            db_info.db_pwd,
-            db_info.comment)
+        return self.storage.update_db_info(
+            db_info.db_name,
+            db_info.db_type,
+            db_info.file_path,
+            db_info.db_host,
+            db_info.db_port,
+            db_info.db_user,
+            db_info.db_pwd,
+            db_info.comment,
+        )
 
     def add_db(self, db_info: DBConfig):
         print(f"add_db:{db_info.__dict__}")
@@ -140,7 +142,6 @@ class ConnectManager:
                     db_info.db_name, db_info.db_type, db_info.file_path
                 )
             else:
-
                 self.storage.add_url_db(
                     db_info.db_name,
                     db_info.db_type,
@@ -151,7 +152,11 @@ class ConnectManager:
                     db_info.comment,
                 )
             # async embedding
-            thread = threading.Thread(target=self.db_summary_client.db_summary_embedding(db_info.db_name, db_info.db_type))
+            thread = threading.Thread(
+                target=self.db_summary_client.db_summary_embedding(
+                    db_info.db_name, db_info.db_type
+                )
+            )
             thread.start()
         except Exception as e:
             raise ValueError("Add db connect info error!" + str(e))
@@ -30,7 +30,11 @@ class DuckDbConnect(RDBMSDatabase):
         return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
 
     def get_users(self):
-        cursor = self.session.execute(text(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"))
+        cursor = self.session.execute(
+            text(
+                f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
+            )
+        )
         users = cursor.fetchall()
         return [(user[0], user[1]) for user in users]
 
@@ -40,6 +44,7 @@ class DuckDbConnect(RDBMSDatabase):
     def get_collation(self):
         """Get collation."""
         return "UTF-8"
 
     def get_charset(self):
         return "UTF-8"
+
@@ -1,25 +1,18 @@
 import uuid
-import json
 import asyncio
-import time
 import os
 from fastapi import (
     APIRouter,
     Request,
     Body,
-    status,
-    HTTPException,
-    Response,
     BackgroundTasks,
 )
 
-from fastapi.responses import JSONResponse, HTMLResponse
-from fastapi.responses import StreamingResponse, FileResponse
-from fastapi.encoders import jsonable_encoder
+from fastapi.responses import StreamingResponse
 from fastapi.exceptions import RequestValidationError
 from typing import List
 
-from pilot.openapi.api_v1.api_view_model import (
+from pilot.openapi.api_view_model import (
     Result,
     ConversationVo,
     MessageVo,
@@ -38,6 +31,8 @@ from pilot.utils import build_logger
 from pilot.common.schema import DBType
 from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
 from pilot.scene.message import OnceConversation
+from pilot.openapi.base import validation_exception_handler
+
 
 router = APIRouter()
 CFG = Config()
@@ -47,14 +42,6 @@ knowledge_service = KnowledgeService()
 
 model_semaphore = None
 global_counter = 0
-static_file_path = os.path.join(os.getcwd(), "server/static")
 
 
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
-    message = ""
-    for error in exc.errors():
-        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
-    return Result.faild(code="E0001", msg=message)
-
-
 def __get_conv_user_message(conversations: dict):
@@ -121,7 +108,9 @@ async def db_support_types():
     support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb]
     db_type_infos = []
     for type in support_types:
-        db_type_infos.append(DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db()))
+        db_type_infos.append(
+            DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
+        )
     return Result[DbTypeInfo].succ(db_type_infos)
 
 
@@ -169,7 +158,7 @@ async def dialogue_scenes():
 
 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
     chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
 ):
     conv_vo = __new_conversation(chat_mode, user_id)
     return Result.succ(conv_vo)
pilot/openapi/api_v1/editor/__init__.py (new file, 0 lines)

pilot/openapi/api_v1/editor/api_editor_v1.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import os
+
+from fastapi import (
+    APIRouter,
+    Request,
+    Body,
+    BackgroundTasks,
+)
+
+
+from typing import List
+
+from pilot.configs.config import Config
+from pilot.server.knowledge.service import KnowledgeService
+
+from pilot.scene.chat_factory import ChatFactory
+from pilot.configs.model_config import LOGDIR
+from pilot.utils import build_logger
+
+from pilot.openapi.api_view_model import (
+    Result,
+    ConversationVo,
+    MessageVo,
+    ChatSceneVo,
+)
+from pilot.openapi.editor_view_model import (
+    ChatDbRounds,
+    ChartDetail,
+    ChatChartEditContext,
+    ChatSqlEditContext,
+)
+
+from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData
+
+from pilot.scene.chat_db.auto_execute.out_parser import SqlAction
+
+router = APIRouter()
+CFG = Config()
+CHAT_FACTORY = ChatFactory()
+logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
+
+
+@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
+async def get_editor_sql_rounds(con_uid: str):
+    return Result.succ(None)
+
+
+@router.get("/v1/editor/sql", response_model=Result[SqlAction])
+async def get_editor_sql(con_uid: str, round: int):
+    return Result.succ(None)
+
+
+@router.get("/v1/editor/chart/details", response_model=Result[ChartDetail])
+async def get_editor_sql_rounds(con_uid: str):
+    return Result.succ(None)
+
+
+@router.post("/v1/editor/chart", response_model=Result[ChartDetail])
+async def get_editor_chart(con_uid: str, chart_uid: str):
+    return Result.succ(None)
+
+
+@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
+async def get_editor_chart(db_name: str, sql: str):
+    return Result.succ(None)
+
+
+@router.post("/v1/editor/chart/run", response_model=Result[ChartData])
+async def get_editor_chart(db_name: str, sql: str):
+    return Result.succ(None)
+
+
+@router.post("/v1/chart/editor/submit", response_model=Result[bool])
+async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
+    return Result.succ(None)
+
+
+@router.post("/v1/sql/editor/submit", response_model=Result[bool])
+async def chart_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
+    return Result.succ(None)
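The new editor routes are all stubs that return Result.succ(None) for now. For reference, a minimal client sketch against two of them follows; the base URL and the assumption that this router is mounted on the main FastAPI app are mine, not part of the diff.

import requests

BASE = "http://127.0.0.1:5000"  # assumed host/port of the running DB-GPT web server

# List the SQL rounds of a conversation (GET /v1/editor/sql/rounds).
rounds = requests.get(f"{BASE}/v1/editor/sql/rounds", params={"con_uid": "some-conv-uid"}).json()

# Submit an edited SQL statement (POST /v1/sql/editor/submit, ChatSqlEditContext body).
payload = {
    "conv_uid": "some-conv-uid",
    "conv_round": 1,
    "old_sql": "SELECT 1",
    "new_sql": "SELECT 2",
    "comment": "manual fix",
    "gmt_create": 0,
    "new_view_info": "",
}
resp = requests.post(f"{BASE}/v1/sql/editor/submit", json=payload).json()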
pilot/openapi/base.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from fastapi import (
+    APIRouter,
+    Request,
+    Body,
+    status,
+    HTTPException,
+    Response,
+    BackgroundTasks,
+)
+
+from fastapi.responses import JSONResponse, HTMLResponse
+from fastapi.responses import StreamingResponse, FileResponse
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import RequestValidationError
+
+from pilot.openapi.api_view_model import (
+    Result,
+    ConversationVo,
+    MessageVo,
+    ChatSceneVo,
+)
+
+
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    message = ""
+    for error in exc.errors():
+        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
+    return Result.faild(code="E0001", msg=message)
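This module centralizes the validation_exception_handler that api_v1.py previously defined inline (see the removal in the @@ -47,14 +42,6 @@ hunk above). A sketch of how it is wired up on the application side, mirroring the app.add_exception_handler(...) call already used by the web server entry point; the surrounding app setup here is illustrative only.

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError

from pilot.openapi.base import validation_exception_handler

app = FastAPI()
# Return Result.faild(code="E0001", ...) whenever request validation fails.
app.add_exception_handler(RequestValidationError, validation_exception_handler)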
pilot/openapi/editor_view_model.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from pydantic import BaseModel, Field
+from typing import TypeVar, Union, List, Generic, Any
+
+
+class ChatDbRounds(BaseModel):
+    round: int
+    db_name: str
+    round_name: str
+
+
+class ChartDetail(BaseModel):
+    chart_uid: str
+    chart_type: str
+    db_name: str
+    chart_name: str
+    chart_value: str
+    chat_round: int  # defualt last round
+
+
+class ChatChartEditContext(BaseModel):
+    conv_uid: str
+    conv_round: int
+    chart_uid: str
+
+    old_sql: str
+    new_sql: str
+    comment: str
+    gmt_create: int
+
+    new_view_info: str
+
+
+class ChatSqlEditContext(BaseModel):
+    conv_uid: str
+    conv_round: int
+
+    old_sql: str
+    new_sql: str
+    comment: str
+    gmt_create: int
+
+    new_view_info: str
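These pydantic models are the request and response shapes used by the editor API above. A small construction example for the SQL edit context, with illustrative values only:

context = ChatSqlEditContext(
    conv_uid="b1c2d3",
    conv_round=2,
    old_sql="SELECT * FROM users",
    new_sql="SELECT id, name FROM users",
    comment="trim columns",
    gmt_create=1690000000,
    new_view_info="",
)
print(context.dict())  # serialize for a /v1/sql/editor/submit request body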
@@ -7,6 +7,7 @@ import sys
 from pilot.summary.db_summary_client import DBSummaryClient
 from pilot.commands.command_mange import CommandRegistry
 from pilot.configs.config import Config
+
 # from pilot.configs.model_config import (
 #     DATASETS_DIR,
 #     KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -5,10 +5,12 @@ import shutil
 import argparse
 import sys
 import logging
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 import signal
 from pilot.configs.config import Config
+
 # from pilot.configs.model_config import (
 #     DATASETS_DIR,
 #     KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -17,7 +19,7 @@ from pilot.configs.config import Config
 # )
 from pilot.utils import build_logger
 
-from pilot.server.webserver_base import server_init
+from pilot.server.base import server_init
 
 from fastapi.staticfiles import StaticFiles
 from fastapi import FastAPI, applications
@@ -1,74 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-
-code_highlight_css = """
-#chatbot .hll { background-color: #ffffcc }
-#chatbot .c { color: #408080; font-style: italic }
-#chatbot .err { border: 1px solid #FF0000 }
-#chatbot .k { color: #008000; font-weight: bold }
-#chatbot .o { color: #666666 }
-#chatbot .ch { color: #408080; font-style: italic }
-#chatbot .cm { color: #408080; font-style: italic }
-#chatbot .cp { color: #BC7A00 }
-#chatbot .cpf { color: #408080; font-style: italic }
-#chatbot .c1 { color: #408080; font-style: italic }
-#chatbot .cs { color: #408080; font-style: italic }
-#chatbot .gd { color: #A00000 }
-#chatbot .ge { font-style: italic }
-#chatbot .gr { color: #FF0000 }
-#chatbot .gh { color: #000080; font-weight: bold }
-#chatbot .gi { color: #00A000 }
-#chatbot .go { color: #888888 }
-#chatbot .gp { color: #000080; font-weight: bold }
-#chatbot .gs { font-weight: bold }
-#chatbot .gu { color: #800080; font-weight: bold }
-#chatbot .gt { color: #0044DD }
-#chatbot .kc { color: #008000; font-weight: bold }
-#chatbot .kd { color: #008000; font-weight: bold }
-#chatbot .kn { color: #008000; font-weight: bold }
-#chatbot .kp { color: #008000 }
-#chatbot .kr { color: #008000; font-weight: bold }
-#chatbot .kt { color: #B00040 }
-#chatbot .m { color: #666666 }
-#chatbot .s { color: #BA2121 }
-#chatbot .na { color: #7D9029 }
-#chatbot .nb { color: #008000 }
-#chatbot .nc { color: #0000FF; font-weight: bold }
-#chatbot .no { color: #880000 }
-#chatbot .nd { color: #AA22FF }
-#chatbot .ni { color: #999999; font-weight: bold }
-#chatbot .ne { color: #D2413A; font-weight: bold }
-#chatbot .nf { color: #0000FF }
-#chatbot .nl { color: #A0A000 }
-#chatbot .nn { color: #0000FF; font-weight: bold }
-#chatbot .nt { color: #008000; font-weight: bold }
-#chatbot .nv { color: #19177C }
-#chatbot .ow { color: #AA22FF; font-weight: bold }
-#chatbot .w { color: #bbbbbb }
-#chatbot .mb { color: #666666 }
-#chatbot .mf { color: #666666 }
-#chatbot .mh { color: #666666 }
-#chatbot .mi { color: #666666 }
-#chatbot .mo { color: #666666 }
-#chatbot .sa { color: #BA2121 }
-#chatbot .sb { color: #BA2121 }
-#chatbot .sc { color: #BA2121 }
-#chatbot .dl { color: #BA2121 }
-#chatbot .sd { color: #BA2121; font-style: italic }
-#chatbot .s2 { color: #BA2121 }
-#chatbot .se { color: #BB6622; font-weight: bold }
-#chatbot .sh { color: #BA2121 }
-#chatbot .si { color: #BB6688; font-weight: bold }
-#chatbot .sx { color: #008000 }
-#chatbot .sr { color: #BB6688 }
-#chatbot .s1 { color: #BA2121 }
-#chatbot .ss { color: #19177C }
-#chatbot .bp { color: #008000 }
-#chatbot .fm { color: #0000FF }
-#chatbot .vc { color: #19177C }
-#chatbot .vg { color: #19177C }
-#chatbot .vi { color: #19177C }
-#chatbot .vm { color: #19177C }
-#chatbot .il { color: #666666 }
-"""
-
-# .highlight { background: #f8f8f8; }
@@ -1,166 +0,0 @@
-"""
-Fork from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/gradio_patch.py
-"""
-from __future__ import annotations
-
-from gradio.components import *
-from markdown2 import Markdown
-
-
-class _Keywords(Enum):
-    NO_VALUE = "NO_VALUE"  # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
-    FINISHED_ITERATING = "FINISHED_ITERATING"  # Used to skip processing of a component's value (needed for generators + state)
-
-
-@document("style")
-class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
-    """
-    Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
-    Preprocessing: this component does *not* accept input.
-    Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
-
-    Demos: chatbot_simple, chatbot_multimodal
-    """
-
-    def __init__(
-        self,
-        value: List[Tuple[str | None, str | None]] | Callable | None = None,
-        color_map: Dict[str, str] | None = None,  # Parameter moved to Chatbot.style()
-        *,
-        label: str | None = None,
-        every: float | None = None,
-        show_label: bool = True,
-        visible: bool = True,
-        elem_id: str | None = None,
-        elem_classes: List[str] | str | None = None,
-        **kwargs,
-    ):
-        """
-        Parameters:
-            value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
-            label: component name in interface.
-            every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
-            show_label: if True, will display label.
-            visible: If False, component will be hidden.
-            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
-            elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
-        """
-        if color_map is not None:
-            warnings.warn(
-                "The 'color_map' parameter has been deprecated.",
-            )
-        # self.md = utils.get_markdown_parser()
-        self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
-        self.select: EventListenerMethod
-        """
-        Event listener for when the user selects message from Chatbot.
-        Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
-        See EventData documentation on how to use this event data.
-        """
-
-        IOComponent.__init__(
-            self,
-            label=label,
-            every=every,
-            show_label=show_label,
-            visible=visible,
-            elem_id=elem_id,
-            elem_classes=elem_classes,
-            value=value,
-            **kwargs,
-        )
-
-    def get_config(self):
-        return {
-            "value": self.value,
-            "selectable": self.selectable,
-            **IOComponent.get_config(self),
-        }
-
-    @staticmethod
-    def update(
-        value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
-        label: str | None = None,
-        show_label: bool | None = None,
-        visible: bool | None = None,
-    ):
-        updated_config = {
-            "label": label,
-            "show_label": show_label,
-            "visible": visible,
-            "value": value,
-            "__type__": "update",
-        }
-        return updated_config
-
-    def _process_chat_messages(
-        self, chat_message: str | Tuple | List | Dict | None
-    ) -> str | Dict | None:
-        if chat_message is None:
-            return None
-        elif isinstance(chat_message, (tuple, list)):
-            mime_type = processing_utils.get_mimetype(chat_message[0])
-            return {
-                "name": chat_message[0],
-                "mime_type": mime_type,
-                "alt_text": chat_message[1] if len(chat_message) > 1 else None,
-                "data": None,  # These last two fields are filled in by the frontend
-                "is_file": True,
-            }
-        elif isinstance(
-            chat_message, dict
-        ):  # This happens for previously processed messages
-            return chat_message
-        elif isinstance(chat_message, str):
-            # return self.md.render(chat_message)
-            return str(self.md.convert(chat_message))
-        else:
-            raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
-
-    def postprocess(
-        self,
-        y: List[
-            Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
-        ],
-    ) -> List[Tuple[str | Dict | None, str | Dict | None]]:
-        """
-        Parameters:
-            y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
-        Returns:
-            List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
-        """
-        if y is None:
-            return []
-        processed_messages = []
-        for message_pair in y:
-            assert isinstance(
-                message_pair, (tuple, list)
-            ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
-            assert (
-                len(message_pair) == 2
-            ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
-            processed_messages.append(
-                (
-                    # self._process_chat_messages(message_pair[0]),
-                    '<pre style="font-family: var(--font)">'
-                    + message_pair[0]
-                    + "</pre>",
-                    self._process_chat_messages(message_pair[1]),
-                )
-            )
-        return processed_messages
-
-    def style(self, height: int | None = None, **kwargs):
-        """
-        This method can be used to change the appearance of the Chatbot component.
-        """
-        if height is not None:
-            self._style["height"] = height
-        if kwargs.get("color_map") is not None:
-            warnings.warn("The 'color_map' parameter has been deprecated.")
-
-        Component.style(
-            self,
-            **kwargs,
-        )
-        return self
@@ -9,7 +9,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
 
-from pilot.openapi.api_v1.api_view_model import Result
+from pilot.openapi.api_view_model import Result
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
 
 from pilot.server.knowledge.service import KnowledgeService
@@ -1,71 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-from langchain.prompts import PromptTemplate
-
-from pilot.configs.config import Config
-from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
-from pilot.logs import logger
-from pilot.model.llm_out.vicuna_llm import VicunaLLM
-from pilot.vector_store.file_loader import KnownLedge2Vector
-
-CFG = Config()
-
-
-class KnownLedgeBaseQA:
-    def __init__(self) -> None:
-        k2v = KnownLedge2Vector()
-        self.vector_store = k2v.init_vector_store()
-        self.llm = VicunaLLM()
-
-    def get_similar_answer(self, query):
-        prompt = PromptTemplate(
-            template=conv_qa_prompt_template, input_variables=["context", "question"]
-        )
-
-        retriever = self.vector_store.as_retriever(
-            search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
-        )
-        docs = retriever.get_relevant_documents(query=query)
-
-        context = [d.page_content for d in docs]
-        result = prompt.format(context="\n".join(context), question=query)
-        return result
-
-    @staticmethod
-    def build_knowledge_prompt(query, docs, state):
-        prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template, input_variables=["context", "question"]
-        )
-        context = [d.page_content for d in docs]
-        result = prompt_template.format(context="\n".join(context), question=query)
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-
-        if len(prompt) > 4000:
-            logger.info("prompt length greater than 4000, rebuild")
-            context = context[:2000]
-            prompt_template = PromptTemplate(
-                template=conv_qa_prompt_template,
-                input_variables=["context", "question"],
-            )
-            result = prompt_template.format(context="\n".join(context), question=query)
-            state.messages[-2][1] = result
-            prompt = state.get_prompt()
-            print("new prompt length:" + str(len(prompt)))
-
-        return prompt
-
-    @staticmethod
-    def build_db_summary_prompt(query, db_profile_summary, state):
-        prompt_template = PromptTemplate(
-            template=conv_db_summary_templates,
-            input_variables=["db_input", "db_profile_summary"],
-        )
-        # context = [d.page_content for d in docs]
-        result = prompt_template.format(
-            db_profile_summary=db_profile_summary, db_input=query
-        )
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-        return prompt
@@ -1,703 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-import threading
-import traceback
-import argparse
-import datetime
-import os
-import shutil
-import sys
-import uuid
-
-import gradio as gr
-
-ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.append(ROOT_PATH)
-
-from pilot.embedding_engine.knowledge_type import KnowledgeType
-
-from pilot.summary.db_summary_client import DBSummaryClient
-
-from pilot.scene.base_chat import BaseChat
-
-from pilot.configs.config import Config
-from pilot.configs.model_config import (
-    DATASETS_DIR,
-    KNOWLEDGE_UPLOAD_ROOT_PATH,
-    LLM_MODEL_CONFIG,
-    LOGDIR,
-)
-
-from pilot.conversation import (
-    conversation_sql_mode,
-    conversation_types,
-    chat_mode_title,
-    default_conversation,
-)
-
-from pilot.server.gradio_css import code_highlight_css
-from pilot.server.gradio_patch import Chatbot as grChatbot
-from pilot.embedding_engine.embedding_engine import EmbeddingEngine
-from pilot.utils import build_logger
-from pilot.vector_store.extract_tovec import (
-    get_vector_storelist,
-    knownledge_tovec_st,
-)
-
-from pilot.scene.base import ChatScene
-from pilot.scene.chat_factory import ChatFactory
-from pilot.language.translation_handler import get_lang_text
-from pilot.server.webserver_base import server_init
-
-
-import uvicorn
-from fastapi import BackgroundTasks, Request
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-from fastapi import FastAPI, applications
-from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.exceptions import RequestValidationError
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-
-from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
-
-# Load plugins
-CFG = Config()
-logger = build_logger("webserver", LOGDIR + "webserver.log")
-headers = {"User-Agent": "dbgpt Client"}
-
-no_change_btn = gr.Button.update()
-enable_btn = gr.Button.update(interactive=True)
-disable_btn = gr.Button.update(interactive=True)
-
-enable_moderation = False
-models = []
-dbs = []
-vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
-autogpt = False
-vector_store_client = None
-vector_store_name = {"vs_name": ""}
-# db_summary = {"dbsummary": ""}
-
-priority = {"vicuna-13b": "aaa"}
-
-CHAT_FACTORY = ChatFactory()
-
-
-llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
-default_knowledge_base_dialogue = get_lang_text(
-    "knowledge_qa_type_default_knowledge_base_dialogue"
-)
-add_knowledge_base_dialogue = get_lang_text(
-    "knowledge_qa_type_add_knowledge_base_dialogue"
-)
-
-url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")
-
-knowledge_qa_type_list = [
-    llm_native_dialogue,
-    default_knowledge_base_dialogue,
-    add_knowledge_base_dialogue,
-]
-
-
-def swagger_monkey_patch(*args, **kwargs):
-    return get_swagger_ui_html(
-        *args,
-        **kwargs,
-        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
-        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
-    )
-
-
-applications.get_swagger_ui_html = swagger_monkey_patch
-
-app = FastAPI()
-origins = ["*"]
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-    allow_headers=["*"],
-)
-
-# app.mount("static", StaticFiles(directory="static"), name="static")
-app.include_router(api_v1)
-app.add_exception_handler(RequestValidationError, validation_exception_handler)
-
-
-def get_simlar(q):
-    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
-    docs = docsearch.similarity_search_with_score(q, k=1)
-
-    contents = [dc.page_content for dc, _ in docs]
-    return "\n".join(contents)
-
-
-def plugins_select_info():
-    plugins_infos: dict = {}
-    for plugin in CFG.plugins:
-        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
-    return plugins_infos
-
-
-get_window_url_params = """
-function() {
-    const params = new URLSearchParams(window.location.search);
-    url_params = Object.fromEntries(params);
-    console.log(url_params);
-    gradioURL = window.location.href
-    if (!gradioURL.endsWith('?__theme=dark')) {
-        window.location.replace(gradioURL + '?__theme=dark');
-    }
-    return url_params;
-    }
-"""
-
-
-def load_demo(url_params, request: gr.Request):
-    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-
-    # dbs = get_database_list()
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if dbs:
-        gr.Dropdown.update(choices=dbs)
-
-    state = default_conversation.copy()
-
-    unique_id = uuid.uuid1()
-    state.conv_id = str(unique_id)
-
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )
-
-
-def get_conv_log_filename():
-    t = datetime.datetime.now()
-    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
-    return name
-
-
-def regenerate(state, request: gr.Request):
-    logger.info(f"regenerate. ip: {request.client.host}")
-    state.messages[-1][-1] = None
-    state.skip_next = False
-    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
-
-
-def clear_history(request: gr.Request):
-    logger.info(f"clear_history. ip: {request.client.host}")
-    state = None
-    return (state, [], "") + (disable_btn,) * 5
-
-
-def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
-    if len(text) <= 0:
-        state.skip_next = True
-        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
-
-    """ Default support 4000 tokens, if tokens too lang, we will cut off  """
-    text = text[:4000]
-    state.append_message(state.roles[0], text)
-    state.append_message(state.roles[1], None)
-    state.skip_next = False
-    ### TODO
-    state.last_user_input = text
-    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
-
-
-def post_process_code(code):
-    sep = "\n```"
-    if sep in code:
-        blocks = code.split(sep)
-        if len(blocks) % 2 == 1:
-            for i in range(1, len(blocks), 2):
-                blocks[i] = blocks[i].replace("\\_", "_")
-        code = sep.join(blocks)
-    return code
-
-
-def get_chat_mode(selected, param=None) -> ChatScene:
-    if chat_mode_title["chat_use_plugin"] == selected:
-        return ChatScene.ChatExecution
-    elif chat_mode_title["sql_generate_diagnostics"] == selected:
-        sql_mode = param
-        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            return ChatScene.ChatWithDbExecute
-        else:
-            return ChatScene.ChatWithDbQA
-    else:
-        mode = param
-        if mode == conversation_types["default_knownledge"]:
-            return ChatScene.ChatDefaultKnowledge
-        elif mode == conversation_types["custome"]:
-            return ChatScene.ChatNewKnowledge
-        elif mode == conversation_types["url"]:
-            return ChatScene.ChatUrlKnowledge
-        else:
-            return ChatScene.ChatNormal
-
-
-def chatbot_callback(state, message):
-    print(f"chatbot_callback:{message}")
-    state.messages[-1][-1] = f"{message}"
-    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-
-def http_bot(
-    state,
-    selected,
-    temperature,
-    max_new_tokens,
-    plugin_selector,
-    mode,
-    sql_mode,
-    db_selector,
-    url_input,
-    knowledge_name,
-):
-    logger.info(
-        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
-    )
-    if chat_mode_title["sql_generate_diagnostics"] == selected:
-        scene: ChatScene = get_chat_mode(selected, sql_mode)
-    elif chat_mode_title["chat_use_plugin"] == selected:
-        scene: ChatScene = get_chat_mode(selected)
-    else:
-        scene: ChatScene = get_chat_mode(selected, mode)
-
-    print(f"chat scene:{scene.value}")
-
-    if ChatScene.ChatWithDbExecute == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "db_name": db_selector,
-            "user_input": state.last_user_input,
-        }
-    elif ChatScene.ChatWithDbQA == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "db_name": db_selector,
-            "user_input": state.last_user_input,
-        }
-    elif ChatScene.ChatExecution == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "plugin_selector": plugin_selector,
-            "user_input": state.last_user_input,
-        }
-    elif ChatScene.ChatNormal == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "user_input": state.last_user_input,
-        }
-    elif ChatScene.ChatDefaultKnowledge == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "user_input": state.last_user_input,
-        }
-    elif ChatScene.ChatNewKnowledge == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "user_input": state.last_user_input,
-            "knowledge_name": knowledge_name,
-        }
-    elif ChatScene.ChatUrlKnowledge == scene:
-        chat_param = {
-            "chat_session_id": state.conv_id,
-            "user_input": state.last_user_input,
-            "url": url_input,
-        }
-    else:
-        state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-    chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value(), **chat_param)
-    if not chat.prompt_template.stream_out:
-        logger.info("not stream out, wait model response!")
-        state.messages[-1][-1] = chat.nostream_call()
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-    else:
-        logger.info("stream out start!")
-        try:
-            response = chat.stream_call()
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                        chunk, chat.skip_echo_len
-                    )
-                    state.messages[-1][-1] = msg
-                    chat.current_message.add_ai_message(msg)
-                    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-            chat.memory.append(chat.current_message)
-        except Exception as e:
-            print(traceback.format_exc())
-            state.messages[-1][
-                -1
-            ] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
-            yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-
-block_css = (
-    code_highlight_css
-    + """
-pre {
-    white-space: pre-wrap; /* Since CSS 2.1 */
-    white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
-    white-space: -pre-wrap; /* Opera 4-6 */
-    white-space: -o-pre-wrap; /* Opera 7 */
-    word-wrap: break-word; /* Internet Explorer 5.5+ */
-}
-#notice_markdown th {
-    display: none;
-}
-"""
-)
-
-
-def change_sql_mode(sql_mode):
-    if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
-        return gr.update(visible=True)
-    else:
-        return gr.update(visible=False)
-
-
-def change_mode(mode):
-    if mode in [add_knowledge_base_dialogue]:
-        return gr.update(visible=True)
-    else:
-        return gr.update(visible=False)
-
-
-def build_single_model_ui():
-    notice_markdown = get_lang_text("db_gpt_introduction")
-    learn_more_markdown = get_lang_text("learn_more_markdown")
-
-    state = gr.State()
-    gr.Markdown(notice_markdown, elem_id="notice_markdown")
-
-    with gr.Accordion(
-        get_lang_text("model_control_param"), open=False, visible=False
-    ) as parameter_row:
-        temperature = gr.Slider(
-            minimum=0.0,
-            maximum=1.0,
-            value=0.7,
-            step=0.1,
-            interactive=True,
-            label="Temperature",
-        )
-
-        max_output_tokens = gr.Slider(
-            minimum=0,
-            maximum=1024,
-            value=512,
-            step=64,
-            interactive=True,
-            label=get_lang_text("max_input_token_size"),
-        )
-
-    tabs = gr.Tabs()
-
-    def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
-        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-        return evt.value
-
-    selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
-    tabs.select(on_select, None, selected)
-
-    with tabs:
-        tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
-        with tab_qa:
-            mode = gr.Radio(
-                [
-                    llm_native_dialogue,
-                    default_knowledge_base_dialogue,
-                    add_knowledge_base_dialogue,
-                    url_knowledge_dialogue,
-                ],
-                show_label=False,
-                value=llm_native_dialogue,
-            )
-            vs_setting = gr.Accordion(
-                get_lang_text("configure_knowledge_base"), open=False, visible=False
-            )
-            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-
-            url_input = gr.Textbox(
-                label=get_lang_text("url_input_label"),
-                lines=1,
-                interactive=True,
-                visible=False,
-            )
-
-            def show_url_input(evt: gr.SelectData):
-                if evt.value == url_knowledge_dialogue:
-                    return gr.update(visible=True)
-                else:
-                    return gr.update(visible=False)
-
-            mode.select(fn=show_url_input, inputs=None, outputs=url_input)
-
-            with vs_setting:
-                vs_name = gr.Textbox(
-                    label=get_lang_text("new_klg_name"), lines=1, interactive=True
-                )
-                vs_add = gr.Button(get_lang_text("add_as_new_klg"))
-                with gr.Column() as doc2vec:
-                    gr.Markdown(get_lang_text("add_file_to_klg"))
-                    with gr.Tab(get_lang_text("upload_file")):
-                        files = gr.File(
-                            label=get_lang_text("add_file"),
-                            file_types=[".txt", ".md", ".docx", ".pdf"],
-                            file_count="multiple",
-                            allow_flagged_uploads=True,
-                            show_label=False,
-                        )
-
-                        load_file_button = gr.Button(
-                            get_lang_text("upload_and_load_to_klg")
-                        )
-                    with gr.Tab(get_lang_text("upload_folder")):
-                        folder_files = gr.File(
-                            label=get_lang_text("add_folder"),
-                            accept_multiple_files=True,
-                            file_count="directory",
-                            show_label=False,
-                        )
-                        load_folder_button = gr.Button(
-                            get_lang_text("upload_and_load_to_klg")
-                        )
-
-        tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
-        with tab_sql:
-            # TODO A selector to choose database
-            with gr.Row(elem_id="db_selector"):
-                db_selector = gr.Dropdown(
-                    label=get_lang_text("please_choose_database"),
-                    choices=dbs,
-                    value=dbs[0] if len(models) > 0 else "",
-                    interactive=True,
-                    show_label=True,
-                ).style(container=False)
-
-            # db_selector.change(fn=db_selector_changed, inputs=db_selector)
-
-            sql_mode = gr.Radio(
-                [
-                    get_lang_text("sql_generate_mode_direct"),
-                    get_lang_text("sql_generate_mode_none"),
-                ],
-                show_label=False,
-                value=get_lang_text("sql_generate_mode_none"),
-            )
-            sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
-            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
-
-        tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
-        # tab_plugin.select(change_func)
-        with tab_plugin:
-            print("tab_plugin in...")
-            with gr.Row(elem_id="plugin_selector"):
-                # TODO
-                plugin_selector = gr.Dropdown(
-                    label=get_lang_text("select_plugin"),
-                    choices=list(plugins_select_info().keys()),
-                    value="",
-                    interactive=True,
-                    show_label=True,
-                    type="value",
-                ).style(container=False)
-
-                def plugin_change(
-                    evt: gr.SelectData,
-                ):  # SelectData is a subclass of EventData
-                    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-                    print(f"user plugin:{plugins_select_info().get(evt.value)}")
-                    return plugins_select_info().get(evt.value)
-
-                plugin_selected = gr.Textbox(
-                    show_label=False, visible=False, placeholder="Selected"
-                )
-                plugin_selector.select(plugin_change, None, plugin_selected)
-
-    with gr.Blocks():
-        chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
-        with gr.Row():
-            with gr.Column(scale=20):
-                textbox = gr.Textbox(
-                    show_label=False,
-                    placeholder="Enter text and press ENTER",
-                    visible=False,
-                ).style(container=False)
-            with gr.Column(scale=2, min_width=50):
-                send_btn = gr.Button(value=get_lang_text("send"), visible=False)
-
-    with gr.Row(visible=False) as button_row:
-        regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
-        clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
-
-    gr.Markdown(learn_more_markdown)
-
-    params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
-
-    btn_list = [regenerate_btn, clear_btn]
-    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
-        http_bot,
-        [state, selected, temperature, max_output_tokens] + params,
-        [state, chatbot] + btn_list,
-    )
-    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
-
-    textbox.submit(
-        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
-    ).then(
-        http_bot,
-        [state, selected, temperature, max_output_tokens] + params,
-        [state, chatbot] + btn_list,
-    )
-
-    send_btn.click(
-        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
-    ).then(
-        http_bot,
-        [state, selected, temperature, max_output_tokens] + params,
-        [state, chatbot] + btn_list,
-    )
-    vs_add.click(
-        fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
-    )
-    load_file_button.click(
-        fn=knowledge_embedding_store,
-        show_progress=True,
-        inputs=[vs_name, files],
-        outputs=[vs_name],
-    )
-    load_folder_button.click(
-        fn=knowledge_embedding_store,
-        show_progress=True,
-        inputs=[vs_name, folder_files],
-        outputs=[vs_name],
-    )
-    return state, chatbot, textbox, send_btn, button_row, parameter_row
-
-
-def build_webdemo():
-    with gr.Blocks(
-        title=get_lang_text("database_smart_assistant"),
-        # theme=gr.themes.Base(),
-        theme=gr.themes.Default(),
-        css=block_css,
-    ) as demo:
-        url_params = gr.JSON(visible=False)
-        (
-            state,
-            chatbot,
-            textbox,
-            send_btn,
-            button_row,
-            parameter_row,
-        ) = build_single_model_ui()
-
-        if args.model_list_mode == "once":
-            demo.load(
-                load_demo,
-                [url_params],
-                [
-                    state,
-                    chatbot,
-                    textbox,
-                    send_btn,
-                    button_row,
-                    parameter_row,
-                ],
-                _js=get_window_url_params,
-            )
-        else:
-            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
-        return demo
-
-
-def save_vs_name(vs_name):
-    vector_store_name["vs_name"] = vs_name
-    return vs_name
-
-
-def knowledge_embedding_store(vs_id, files):
-    # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
-    if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
-        os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
-    for file in files:
-        filename = os.path.split(file.name)[-1]
-        shutil.move(
-            file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
-        )
-        knowledge_embedding_client = EmbeddingEngine(
-            knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-            knowledge_type=KnowledgeType.DOCUMENT.value,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config={
-                "vector_store_name": vector_store_name["vs_name"],
-                "vector_store_type": CFG.VECTOR_STORE_TYPE,
-                "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-            },
-        )
-        knowledge_embedding_client.knowledge_embedding()
-
-    logger.info("knowledge embedding success")
-    return vs_id
-
-
-def async_db_summery():
-    client = DBSummaryClient()
-    thread = threading.Thread(target=client.init_db_summary)
-    thread.start()
-
-
-def signal_handler(sig, frame):
-    print("in order to avoid chroma db atexit problem")
-    os._exit(0)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
-    )
-    parser.add_argument(
-        "-new", "--new", action="store_true", help="enable new http mode"
-    )
-
-    # old version server config
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
-    parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument("--share", default=False, action="store_true")
-
-    # init server config
-    args = parser.parse_args()
-    server_init(args)
-    dbs = CFG.LOCAL_DB_MANAGE.get_db_names()
-    demo = build_webdemo()
-    demo.queue(
-        concurrency_count=args.concurrency_count,
-        status_update_rate=10,
-        api_open=False,
-    ).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=args.share,
-        max_threads=200,
-    )
@@ -58,8 +58,8 @@ class DBSummaryClient:
         )
         embedding.source_embedding()
         for (
             table_name,
             table_summary,
         ) in db_summary_client.get_table_summary().items():
             table_vector_store_config = {
                 "vector_store_name": dbname + "_" + table_name + "_ts",
@@ -5,12 +5,13 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
 
 CFG = Config()
 
+
 class RdbmsSummary(DBSummary):
     """Get mysql summary template."""
 
     def __init__(self, name, type):
         self.name = name
         self.type = type
         self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
         self.tables = {}
         self.tables_info = []
|
|||||||
return self.summery_template.format(
|
return self.summery_template.format(
|
||||||
name=self.name, bind_fields=self.bind_fields
|
name=self.name, bind_fields=self.bind_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -43,8 +43,6 @@ pycocoevalcap
cpm_kernels
umap-learn
notebook
-gradio==3.23
-gradio-client==0.0.8
wandb
llama-index==0.5.27
 
@@ -5,11 +5,10 @@ import json
 from urllib.parse import urljoin
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-from pilot.openapi.api_v1.api_view_model import Result
+from pilot.openapi.api_view_model import Result
 from pilot.server.knowledge.request.request import (
     KnowledgeQueryRequest,
     KnowledgeDocumentRequest,
-    DocumentSyncRequest,
     ChunkQueryRequest,
     DocumentQueryRequest,
 )