mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 08:40:36 +00:00
feat(core): Enhance server request processing performance
This commit is contained in:
parent
185d4366b9
commit
48cd2d6a4a
@ -1,5 +1,6 @@
|
|||||||
from .base import MemoryStoreType
|
from .base import MemoryStoreType
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -18,7 +19,15 @@ class ChatHistory:
|
|||||||
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
|
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
|
||||||
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
|
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
|
||||||
|
|
||||||
def get_store_instance(self, chat_session_id):
|
def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
|
||||||
|
"""New store instance for store chat histories
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_session_id (str): conversation session id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseChatHistoryMemory: Store instance
|
||||||
|
"""
|
||||||
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
|
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
|
||||||
chat_session_id
|
chat_session_id
|
||||||
)
|
)
|
||||||
|
@ -19,6 +19,7 @@ from fastapi.responses import StreamingResponse
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from typing import List
|
from typing import List
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from concurrent.futures import Executor
|
||||||
|
|
||||||
from pilot.component import ComponentType
|
from pilot.component import ComponentType
|
||||||
from pilot.openapi.api_view_model import (
|
from pilot.openapi.api_view_model import (
|
||||||
@ -46,6 +47,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
|
|||||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||||
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||||
from pilot.model.base import FlatSupportedModel
|
from pilot.model.base import FlatSupportedModel
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -129,6 +131,13 @@ def get_worker_manager() -> WorkerManager:
|
|||||||
return worker_manager
|
return worker_manager
|
||||||
|
|
||||||
|
|
||||||
|
def get_executor() -> Executor:
|
||||||
|
"""Get the global default executor"""
|
||||||
|
return CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
|
).create()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||||
async def db_connect_list():
|
async def db_connect_list():
|
||||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||||
@ -158,6 +167,7 @@ async def async_db_summary_embedding(db_name, db_type):
|
|||||||
@router.post("/v1/chat/db/test/connect", response_model=Result[bool])
|
@router.post("/v1/chat/db/test/connect", response_model=Result[bool])
|
||||||
async def test_connect(db_config: DBConfig = Body()):
|
async def test_connect(db_config: DBConfig = Body()):
|
||||||
try:
|
try:
|
||||||
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
|
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
|
||||||
return Result.succ(True)
|
return Result.succ(True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -166,6 +176,7 @@ async def test_connect(db_config: DBConfig = Body()):
|
|||||||
|
|
||||||
@router.post("/v1/chat/db/summary", response_model=Result[bool])
|
@router.post("/v1/chat/db/summary", response_model=Result[bool])
|
||||||
async def db_summary(db_name: str, db_type: str):
|
async def db_summary(db_name: str, db_type: str):
|
||||||
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
async_db_summary_embedding(db_name, db_type)
|
async_db_summary_embedding(db_name, db_type)
|
||||||
return Result.succ(True)
|
return Result.succ(True)
|
||||||
|
|
||||||
@ -185,6 +196,7 @@ async def db_support_types():
|
|||||||
async def dialogue_list(user_id: str = None):
|
async def dialogue_list(user_id: str = None):
|
||||||
dialogues: List = []
|
dialogues: List = []
|
||||||
chat_history_service = ChatHistory()
|
chat_history_service = ChatHistory()
|
||||||
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
datas = chat_history_service.get_store_cls().conv_list(user_id)
|
datas = chat_history_service.get_store_cls().conv_list(user_id)
|
||||||
for item in datas:
|
for item in datas:
|
||||||
conv_uid = item.get("conv_uid")
|
conv_uid = item.get("conv_uid")
|
||||||
@ -285,7 +297,7 @@ async def params_load(
|
|||||||
select_param=doc_file.filename,
|
select_param=doc_file.filename,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
chat: BaseChat = get_chat_instance(dialogue)
|
chat: BaseChat = await get_chat_instance(dialogue)
|
||||||
resp = await chat.prepare()
|
resp = await chat.prepare()
|
||||||
|
|
||||||
### refresh messages
|
### refresh messages
|
||||||
@ -299,6 +311,7 @@ async def params_load(
|
|||||||
async def dialogue_delete(con_uid: str):
|
async def dialogue_delete(con_uid: str):
|
||||||
history_fac = ChatHistory()
|
history_fac = ChatHistory()
|
||||||
history_mem = history_fac.get_store_instance(con_uid)
|
history_mem = history_fac.get_store_instance(con_uid)
|
||||||
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
history_mem.delete()
|
history_mem.delete()
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
@ -324,10 +337,11 @@ def get_hist_messages(conv_uid: str):
|
|||||||
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
|
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
|
||||||
async def dialogue_history_messages(con_uid: str):
|
async def dialogue_history_messages(con_uid: str):
|
||||||
print(f"dialogue_history_messages:{con_uid}")
|
print(f"dialogue_history_messages:{con_uid}")
|
||||||
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
return Result.succ(get_hist_messages(con_uid))
|
return Result.succ(get_hist_messages(con_uid))
|
||||||
|
|
||||||
|
|
||||||
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
||||||
logger.info(f"get_chat_instance:{dialogue}")
|
logger.info(f"get_chat_instance:{dialogue}")
|
||||||
if not dialogue.chat_mode:
|
if not dialogue.chat_mode:
|
||||||
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
||||||
@ -346,8 +360,14 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
|||||||
"select_param": dialogue.select_param,
|
"select_param": dialogue.select_param,
|
||||||
"model_name": dialogue.model_name,
|
"model_name": dialogue.model_name,
|
||||||
}
|
}
|
||||||
chat: BaseChat = CHAT_FACTORY.get_implementation(
|
# chat: BaseChat = CHAT_FACTORY.get_implementation(
|
||||||
dialogue.chat_mode, **{"chat_param": chat_param}
|
# dialogue.chat_mode, **{"chat_param": chat_param}
|
||||||
|
# )
|
||||||
|
chat: BaseChat = await blocking_func_to_async(
|
||||||
|
get_executor(),
|
||||||
|
CHAT_FACTORY.get_implementation,
|
||||||
|
dialogue.chat_mode,
|
||||||
|
**{"chat_param": chat_param},
|
||||||
)
|
)
|
||||||
return chat
|
return chat
|
||||||
|
|
||||||
@ -357,7 +377,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
|
|||||||
# dialogue.model_name = CFG.LLM_MODEL
|
# dialogue.model_name = CFG.LLM_MODEL
|
||||||
logger.info(f"chat_prepare:{dialogue}")
|
logger.info(f"chat_prepare:{dialogue}")
|
||||||
## check conv_uid
|
## check conv_uid
|
||||||
chat: BaseChat = get_chat_instance(dialogue)
|
chat: BaseChat = await get_chat_instance(dialogue)
|
||||||
if len(chat.history_message) > 0:
|
if len(chat.history_message) > 0:
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
resp = await chat.prepare()
|
resp = await chat.prepare()
|
||||||
@ -369,7 +389,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
print(
|
print(
|
||||||
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
||||||
)
|
)
|
||||||
chat: BaseChat = get_chat_instance(dialogue)
|
chat: BaseChat = await get_chat_instance(dialogue)
|
||||||
# background_tasks = BackgroundTasks()
|
# background_tasks = BackgroundTasks()
|
||||||
# background_tasks.add_task(release_model_semaphore)
|
# background_tasks.add_task(release_model_semaphore)
|
||||||
headers = {
|
headers = {
|
||||||
|
@ -12,6 +12,7 @@ from pilot.prompts.prompt_new import PromptTemplate
|
|||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
from pilot.utils import get_or_create_event_loop
|
from pilot.utils import get_or_create_event_loop
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
from pydantic import Extra
|
from pydantic import Extra
|
||||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||||
|
|
||||||
@ -80,6 +81,10 @@ class BaseChat(ABC):
|
|||||||
self.current_message.param_type = self.chat_mode.param_types()[0]
|
self.current_message.param_type = self.chat_mode.param_types()[0]
|
||||||
self.current_message.param_value = chat_param["select_param"]
|
self.current_message.param_value = chat_param["select_param"]
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
|
# The executor to submit blocking function
|
||||||
|
self._executor = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
|
).create()
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -92,8 +97,14 @@ class BaseChat(ABC):
|
|||||||
raise NotImplementedError("Not supported for this chat type.")
|
raise NotImplementedError("Not supported for this chat type.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
pass
|
"""Generate input to LLM
|
||||||
|
|
||||||
|
Please note that you must not perform any blocking operations in this function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dictionary to be formatted by prompt template
|
||||||
|
"""
|
||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
return prompt_response
|
return prompt_response
|
||||||
@ -116,8 +127,8 @@ class BaseChat(ABC):
|
|||||||
speak_to_user = prompt_define_response
|
speak_to_user = prompt_define_response
|
||||||
return speak_to_user
|
return speak_to_user
|
||||||
|
|
||||||
def __call_base(self):
|
async def __call_base(self):
|
||||||
input_values = self.generate_input_values()
|
input_values = await self.generate_input_values()
|
||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
self.current_message.chat_order = len(self.history_message) + 1
|
self.current_message.chat_order = len(self.history_message) + 1
|
||||||
self.current_message.add_user_message(self.current_user_input)
|
self.current_message.add_user_message(self.current_user_input)
|
||||||
@ -159,7 +170,7 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
async def stream_call(self):
|
async def stream_call(self):
|
||||||
# TODO Retry when server connection error
|
# TODO Retry when server connection error
|
||||||
payload = self.__call_base()
|
payload = await self.__call_base()
|
||||||
|
|
||||||
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
@ -190,7 +201,7 @@ class BaseChat(ABC):
|
|||||||
self.memory.append(self.current_message)
|
self.memory.append(self.current_message)
|
||||||
|
|
||||||
async def nostream_call(self):
|
async def nostream_call(self):
|
||||||
payload = self.__call_base()
|
payload = await self.__call_base()
|
||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
@ -216,14 +227,24 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
### run
|
### run
|
||||||
result = self.do_action(prompt_define_response)
|
# result = self.do_action(prompt_define_response)
|
||||||
|
result = await blocking_func_to_async(
|
||||||
|
self._executor, self.do_action, prompt_define_response
|
||||||
|
)
|
||||||
|
|
||||||
### llm speaker
|
### llm speaker
|
||||||
speak_to_user = self.get_llm_speak(prompt_define_response)
|
speak_to_user = self.get_llm_speak(prompt_define_response)
|
||||||
|
|
||||||
view_message = self.prompt_template.output_parser.parse_view_response(
|
# view_message = self.prompt_template.output_parser.parse_view_response(
|
||||||
speak_to_user, result
|
# speak_to_user, result
|
||||||
|
# )
|
||||||
|
view_message = await blocking_func_to_async(
|
||||||
|
self._executor,
|
||||||
|
self.prompt_template.output_parser.parse_view_response,
|
||||||
|
speak_to_user,
|
||||||
|
result,
|
||||||
)
|
)
|
||||||
|
|
||||||
view_message = view_message.replace("\n", "\\n")
|
view_message = view_message.replace("\n", "\\n")
|
||||||
self.current_message.add_view_message(view_message)
|
self.current_message.add_view_message(view_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -51,7 +51,7 @@ class ChatAgent(BaseChat):
|
|||||||
|
|
||||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict[str, str]:
|
||||||
input_values = {
|
input_values = {
|
||||||
"user_goal": self.current_user_input,
|
"user_goal": self.current_user_input,
|
||||||
"expand_constraints": self.__list_to_prompt_str(
|
"expand_constraints": self.__list_to_prompt_str(
|
||||||
|
@ -12,6 +12,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
|||||||
)
|
)
|
||||||
from pilot.scene.chat_dashboard.prompt import prompt
|
from pilot.scene.chat_dashboard.prompt import prompt
|
||||||
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||||
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ class ChatDashboard(BaseChat):
|
|||||||
data = f.read()
|
data = f.read()
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
try:
|
try:
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -60,9 +61,16 @@ class ChatDashboard(BaseChat):
|
|||||||
|
|
||||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||||
try:
|
try:
|
||||||
table_infos = client.get_similar_tables(
|
table_infos = await blocking_func_to_async(
|
||||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
self._executor,
|
||||||
|
client.get_similar_tables,
|
||||||
|
self.db_name,
|
||||||
|
self.current_user_input,
|
||||||
|
self.top_k,
|
||||||
)
|
)
|
||||||
|
# table_infos = client.get_similar_tables(
|
||||||
|
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||||
|
# )
|
||||||
print("dashboard vector find tables:{}", table_infos)
|
print("dashboard vector find tables:{}", table_infos)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("db summary find error!" + str(e))
|
print("db summary find error!" + str(e))
|
||||||
|
@ -62,7 +62,7 @@ class ChatExcel(BaseChat):
|
|||||||
# ]
|
# ]
|
||||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"user_input": self.current_user_input,
|
"user_input": self.current_user_input,
|
||||||
"table_name": self.excel_reader.table_name,
|
"table_name": self.excel_reader.table_name,
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
from typing import Any, Dict
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pilot.scene.base_message import (
|
from pilot.scene.base_message import (
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
@ -13,6 +12,7 @@ from pilot.configs.config import Config
|
|||||||
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
||||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||||
from pilot.json_utils.utilities import DateTimeEncoder
|
from pilot.json_utils.utilities import DateTimeEncoder
|
||||||
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -44,13 +44,15 @@ class ExcelLearning(BaseChat):
|
|||||||
if parent_mode:
|
if parent_mode:
|
||||||
self.current_message.chat_mode = parent_mode.value()
|
self.current_message.chat_mode = parent_mode.value()
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
colunms, datas = self.excel_reader.get_sample_data()
|
# colunms, datas = self.excel_reader.get_sample_data()
|
||||||
|
colunms, datas = await blocking_func_to_async(
|
||||||
|
self._executor, self.excel_reader.get_sample_data
|
||||||
|
)
|
||||||
|
copy_datas = datas.copy()
|
||||||
datas.insert(0, colunms)
|
datas.insert(0, colunms)
|
||||||
|
|
||||||
input_values = {
|
input_values = {
|
||||||
"data_example": json.dumps(
|
"data_example": json.dumps(copy_datas, cls=DateTimeEncoder),
|
||||||
self.excel_reader.get_sample_data(), cls=DateTimeEncoder
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
|
|||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||||
self.top_k: int = 200
|
self.top_k: int = 200
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
"""
|
"""
|
||||||
generate input values
|
generate input values
|
||||||
"""
|
"""
|
||||||
@ -47,19 +48,27 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Could not import DBSummaryClient. ")
|
raise ValueError("Could not import DBSummaryClient. ")
|
||||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||||
|
table_infos = None
|
||||||
try:
|
try:
|
||||||
table_infos = client.get_db_summary(
|
# table_infos = client.get_db_summary(
|
||||||
dbname=self.db_name,
|
# dbname=self.db_name,
|
||||||
query=self.current_user_input,
|
# query=self.current_user_input,
|
||||||
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||||
|
# )
|
||||||
|
table_infos = await blocking_func_to_async(
|
||||||
|
self._executor,
|
||||||
|
client.get_db_summary,
|
||||||
|
self.db_name,
|
||||||
|
self.current_user_input,
|
||||||
|
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("db summary find error!" + str(e))
|
print("db summary find error!" + str(e))
|
||||||
table_infos = self.database.table_simple_info()
|
|
||||||
if not table_infos:
|
if not table_infos:
|
||||||
table_infos = self.database.table_simple_info()
|
# table_infos = self.database.table_simple_info()
|
||||||
|
table_infos = await blocking_func_to_async(
|
||||||
# table_infos = self.database.table_simple_info()
|
self._executor, self.database.table_simple_info
|
||||||
|
)
|
||||||
|
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
|
@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
|
|||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||||
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ class ChatWithDbQA(BaseChat):
|
|||||||
else len(self.tables)
|
else len(self.tables)
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
table_info = ""
|
table_info = ""
|
||||||
dialect = "mysql"
|
dialect = "mysql"
|
||||||
try:
|
try:
|
||||||
@ -48,12 +49,22 @@ class ChatWithDbQA(BaseChat):
|
|||||||
if self.db_name:
|
if self.db_name:
|
||||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||||
try:
|
try:
|
||||||
table_infos = client.get_db_summary(
|
# table_infos = client.get_db_summary(
|
||||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||||
|
# )
|
||||||
|
table_infos = await blocking_func_to_async(
|
||||||
|
self._executor,
|
||||||
|
client.get_db_summary,
|
||||||
|
self.db_name,
|
||||||
|
self.current_user_input,
|
||||||
|
self.top_k,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("db summary find error!" + str(e))
|
print("db summary find error!" + str(e))
|
||||||
table_infos = self.database.table_simple_info()
|
# table_infos = self.database.table_simple_info()
|
||||||
|
table_infos = await blocking_func_to_async(
|
||||||
|
self._executor, self.database.table_simple_info
|
||||||
|
)
|
||||||
|
|
||||||
# table_infos = self.database.table_simple_info()
|
# table_infos = self.database.table_simple_info()
|
||||||
dialect = self.database.dialect
|
dialect = self.database.dialect
|
||||||
|
@ -50,7 +50,7 @@ class ChatWithPlugin(BaseChat):
|
|||||||
self.plugins_prompt_generator
|
self.plugins_prompt_generator
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
"constraints": self.__list_to_prompt_str(
|
"constraints": self.__list_to_prompt_str(
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Dict
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
@ -30,7 +31,7 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
self.db_input = db_select
|
self.db_input = db_select
|
||||||
self.db_summary = db_summary
|
self.db_summary = db_summary
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"db_input": self.db_input,
|
"db_input": self.db_input,
|
||||||
"db_profile_summary": self.db_summary,
|
"db_profile_summary": self.db_summary,
|
||||||
|
@ -12,6 +12,7 @@ from pilot.configs.model_config import (
|
|||||||
|
|
||||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||||
from pilot.server.knowledge.service import KnowledgeService
|
from pilot.server.knowledge.service import KnowledgeService
|
||||||
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ class ChatKnowledge(BaseChat):
|
|||||||
self.prompt_template.template_is_strict = False
|
self.prompt_template.template_is_strict = False
|
||||||
|
|
||||||
async def stream_call(self):
|
async def stream_call(self):
|
||||||
input_values = self.generate_input_values()
|
input_values = await self.generate_input_values()
|
||||||
# Source of knowledge file
|
# Source of knowledge file
|
||||||
relations = input_values.get("relations")
|
relations = input_values.get("relations")
|
||||||
last_output = None
|
last_output = None
|
||||||
@ -85,12 +86,18 @@ class ChatKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
yield last_output
|
yield last_output
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
if self.space_context:
|
if self.space_context:
|
||||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||||
self.prompt_template.template = self.space_context["prompt"]["template"]
|
self.prompt_template.template = self.space_context["prompt"]["template"]
|
||||||
docs = self.knowledge_embedding_client.similar_search(
|
# docs = self.knowledge_embedding_client.similar_search(
|
||||||
self.current_user_input, self.top_k
|
# self.current_user_input, self.top_k
|
||||||
|
# )
|
||||||
|
docs = await blocking_func_to_async(
|
||||||
|
self._executor,
|
||||||
|
self.knowledge_embedding_client.similar_search,
|
||||||
|
self.current_user_input,
|
||||||
|
self.top_k,
|
||||||
)
|
)
|
||||||
if not docs:
|
if not docs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -21,7 +21,7 @@ class ChatNormal(BaseChat):
|
|||||||
chat_param=chat_param,
|
chat_param=chat_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_input_values(self):
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {"input": self.current_user_input}
|
input_values = {"input": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
from typing import Callable, Awaitable, Any
|
||||||
|
import asyncio
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||||
|
|
||||||
@ -24,3 +27,34 @@ class DefaultExecutorFactory(ExecutorFactory):
|
|||||||
|
|
||||||
def create(self) -> Executor:
|
def create(self) -> Executor:
|
||||||
return self._executor
|
return self._executor
|
||||||
|
|
||||||
|
|
||||||
|
BlockingFunction = Callable[..., Any]
|
||||||
|
|
||||||
|
|
||||||
|
async def blocking_func_to_async(
|
||||||
|
executor: Executor, func: BlockingFunction, *args, **kwargs
|
||||||
|
):
|
||||||
|
"""Run a potentially blocking function within an executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
executor (Executor): The concurrent.futures.Executor to run the function within.
|
||||||
|
func (ApplyFunction): The callable function, which should be a synchronous function.
|
||||||
|
It should accept any number and type of arguments and return an asynchronous coroutine.
|
||||||
|
*args (Any): Any additional arguments to pass to the function.
|
||||||
|
**kwargs (Any): Other arguments to pass to the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the function's execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provided function 'func' is an asynchronous coroutine function.
|
||||||
|
|
||||||
|
This function allows you to execute a potentially blocking function within an executor.
|
||||||
|
It expects 'func' to be a synchronous function and will raise an error if 'func' is an asynchronous coroutine.
|
||||||
|
"""
|
||||||
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
raise ValueError(f"The function {func} is not blocking function")
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
sync_function_noargs = partial(func, *args, **kwargs)
|
||||||
|
return await loop.run_in_executor(executor, sync_function_noargs)
|
||||||
|
@ -35,15 +35,14 @@ clone_repositories() {
|
|||||||
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
|
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
|
||||||
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
|
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
|
||||||
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
|
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
|
||||||
git clone https://huggingface.co/THUDM/chatglm2-6b-int4
|
git clone https://huggingface.co/THUDM/chatglm2-6b
|
||||||
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
|
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
|
||||||
rm -rf /root/DB-GPT/models/chatglm2-6b-int4/.git
|
rm -rf /root/DB-GPT/models/chatglm2-6b/.git
|
||||||
}
|
}
|
||||||
|
|
||||||
install_dbgpt_packages() {
|
install_dbgpt_packages() {
|
||||||
conda activate dbgpt && cd /root/DB-GPT && pip install -e . && cp .env.template .env
|
conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
|
||||||
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b-int4/' .env
|
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clean_up() {
|
clean_up() {
|
||||||
|
4
setup.py
4
setup.py
@ -317,6 +317,8 @@ def core_requires():
|
|||||||
# TODO move transformers to default
|
# TODO move transformers to default
|
||||||
"transformers>=4.31.0",
|
"transformers>=4.31.0",
|
||||||
"alembic==1.12.0",
|
"alembic==1.12.0",
|
||||||
|
# for excel
|
||||||
|
"openpyxl",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -361,6 +363,8 @@ def quantization_requires():
|
|||||||
)
|
)
|
||||||
pkgs = [f"bitsandbytes @ {local_pkg}"]
|
pkgs = [f"bitsandbytes @ {local_pkg}"]
|
||||||
print(pkgs)
|
print(pkgs)
|
||||||
|
# For chatglm2-6b-int4
|
||||||
|
pkgs += ["cpm_kernels"]
|
||||||
setup_spec.extras["quantization"] = pkgs
|
setup_spec.extras["quantization"] = pkgs
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user