feat(core): Enhance server request processing performance (#722)

Close #720 
**Others**: 
- Fix chat tracer no spans bug
- Modify AutoDL setup script
This commit is contained in:
Aries-ckt 2023-10-24 17:23:29 +08:00 committed by GitHub
commit 96a48675e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 201 additions and 67 deletions

View File

@ -1,7 +1,6 @@
from pandas import DataFrame from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd import pandas as pd
import uuid import uuid
import os import os

View File

@ -12,7 +12,7 @@ from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from ..common.schema import PluginStorageType from ..common.schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger("agent_hub") logger = logging.getLogger(__name__)
Default_User = "default" Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git" DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = "" TEMP_PLUGIN_PATH = ""

View File

@ -9,6 +9,7 @@ import requests
import git import git
import threading import threading
import datetime import datetime
import logging
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
@ -19,7 +20,8 @@ from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
logger = logging.getLogger(__name__)
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:

View File

@ -20,7 +20,7 @@ from urllib.parse import quote
from pilot.configs.config import Config from pilot.configs.config import Config
logger = logging.getLogger("meta_data") logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data") default_db_path = os.path.join(os.getcwd(), "meta_data")

View File

@ -249,7 +249,8 @@ def remove_color_codes(s: str) -> str:
return ansi_escape.sub("", s) return ansi_escape.sub("", s)
logger: Logger = Logger() # Remove current logger
# logger: Logger = Logger()
def print_assistant_thoughts( def print_assistant_thoughts(

View File

@ -1,5 +1,6 @@
from .base import MemoryStoreType from .base import MemoryStoreType
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
CFG = Config() CFG = Config()
@ -18,7 +19,15 @@ class ChatHistory:
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
def get_store_instance(self, chat_session_id): def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
"""New store instance for store chat histories
Args:
chat_session_id (str): conversation session id
Returns:
BaseChatHistoryMemory: Store instance
"""
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)( return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
chat_session_id chat_session_id
) )

View File

@ -14,7 +14,7 @@ from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
from pilot.memory.chat_history.base import MemoryStoreType from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config() CFG = Config()
logger = logging.getLogger("db_chat_history") logger = logging.getLogger(__name__)
class DbHistoryMemory(BaseChatHistoryMemory): class DbHistoryMemory(BaseChatHistoryMemory):

View File

@ -19,6 +19,7 @@ from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
import tempfile import tempfile
from concurrent.futures import Executor
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
@ -46,6 +47,8 @@ from pilot.summary.db_summary_client import DBSummaryClient
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer, SpanType
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
@ -129,6 +132,13 @@ def get_worker_manager() -> WorkerManager:
return worker_manager return worker_manager
def get_executor() -> Executor:
    """Return the executor created by the default executor factory.

    The factory is looked up on ``CFG.SYSTEM_APP`` under
    ``ComponentType.EXECUTOR_DEFAULT`` and its ``create()`` result is returned.

    Returns:
        Executor: The executor handed out by the registered ``ExecutorFactory``.
    """
    executor_factory = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    )
    return executor_factory.create()
@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) @router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list(): async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -158,6 +168,7 @@ async def async_db_summary_embedding(db_name, db_type):
@router.post("/v1/chat/db/test/connect", response_model=Result[bool]) @router.post("/v1/chat/db/test/connect", response_model=Result[bool])
async def test_connect(db_config: DBConfig = Body()): async def test_connect(db_config: DBConfig = Body()):
try: try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config) CFG.LOCAL_DB_MANAGE.test_connect(db_config)
return Result.succ(True) return Result.succ(True)
except Exception as e: except Exception as e:
@ -166,6 +177,7 @@ async def test_connect(db_config: DBConfig = Body()):
@router.post("/v1/chat/db/summary", response_model=Result[bool]) @router.post("/v1/chat/db/summary", response_model=Result[bool])
async def db_summary(db_name: str, db_type: str): async def db_summary(db_name: str, db_type: str):
# TODO Change the synchronous call to the asynchronous call
async_db_summary_embedding(db_name, db_type) async_db_summary_embedding(db_name, db_type)
return Result.succ(True) return Result.succ(True)
@ -185,6 +197,7 @@ async def db_support_types():
async def dialogue_list(user_id: str = None): async def dialogue_list(user_id: str = None):
dialogues: List = [] dialogues: List = []
chat_history_service = ChatHistory() chat_history_service = ChatHistory()
# TODO Change the synchronous call to the asynchronous call
datas = chat_history_service.get_store_cls().conv_list(user_id) datas = chat_history_service.get_store_cls().conv_list(user_id)
for item in datas: for item in datas:
conv_uid = item.get("conv_uid") conv_uid = item.get("conv_uid")
@ -285,7 +298,7 @@ async def params_load(
select_param=doc_file.filename, select_param=doc_file.filename,
model_name=model_name, model_name=model_name,
) )
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare() resp = await chat.prepare()
### refresh messages ### refresh messages
@ -299,6 +312,7 @@ async def params_load(
async def dialogue_delete(con_uid: str): async def dialogue_delete(con_uid: str):
history_fac = ChatHistory() history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid) history_mem = history_fac.get_store_instance(con_uid)
# TODO Change the synchronous call to the asynchronous call
history_mem.delete() history_mem.delete()
return Result.succ(None) return Result.succ(None)
@ -324,10 +338,11 @@ def get_hist_messages(conv_uid: str):
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo]) @router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str): async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}") print(f"dialogue_history_messages:{con_uid}")
# TODO Change the synchronous call to the asynchronous call
return Result.succ(get_hist_messages(con_uid)) return Result.succ(get_hist_messages(con_uid))
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
logger.info(f"get_chat_instance:{dialogue}") logger.info(f"get_chat_instance:{dialogue}")
if not dialogue.chat_mode: if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value() dialogue.chat_mode = ChatScene.ChatNormal.value()
@ -346,8 +361,14 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
"select_param": dialogue.select_param, "select_param": dialogue.select_param,
"model_name": dialogue.model_name, "model_name": dialogue.model_name,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation( # chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param} # dialogue.chat_mode, **{"chat_param": chat_param}
# )
chat: BaseChat = await blocking_func_to_async(
get_executor(),
CHAT_FACTORY.get_implementation,
dialogue.chat_mode,
**{"chat_param": chat_param},
) )
return chat return chat
@ -357,7 +378,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL # dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}") logger.info(f"chat_prepare:{dialogue}")
## check conv_uid ## check conv_uid
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = await get_chat_instance(dialogue)
if len(chat.history_message) > 0: if len(chat.history_message) > 0:
return Result.succ(None) return Result.succ(None)
resp = await chat.prepare() resp = await chat.prepare()
@ -369,7 +390,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
print( print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
) )
chat: BaseChat = get_chat_instance(dialogue) with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = await get_chat_instance(dialogue)
# background_tasks = BackgroundTasks() # background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore) # background_tasks.add_task(release_model_semaphore)
headers = { headers = {
@ -420,8 +444,9 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana
async def no_stream_generator(chat): async def no_stream_generator(chat):
msg = await chat.nostream_call() with root_tracer.start_span("no_stream_generator"):
yield f"data: {msg}\n\n" msg = await chat.nostream_call()
yield f"data: {msg}\n\n"
async def stream_generator(chat, incremental: bool, model_name: str): async def stream_generator(chat, incremental: bool, model_name: str):
@ -438,6 +463,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
Yields: Yields:
_type_: streaming responses _type_: streaming responses
""" """
span = root_tracer.start_span("stream_generator")
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong." msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
stream_id = f"chatcmpl-{str(uuid.uuid1())}" stream_id = f"chatcmpl-{str(uuid.uuid1())}"
@ -463,6 +489,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
await asyncio.sleep(0.02) await asyncio.sleep(0.02)
if incremental: if incremental:
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
span.end()
def message2Vo(message: dict, order, model_name) -> MessageVo: def message2Vo(message: dict, order, model_name) -> MessageVo:

View File

@ -12,6 +12,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pydantic import Extra from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@ -80,6 +81,10 @@ class BaseChat(ABC):
self.current_message.param_type = self.chat_mode.param_types()[0] self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = chat_param["select_param"] self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -92,8 +97,14 @@ class BaseChat(ABC):
raise NotImplementedError("Not supported for this chat type.") raise NotImplementedError("Not supported for this chat type.")
@abstractmethod @abstractmethod
def generate_input_values(self): async def generate_input_values(self) -> Dict:
pass """Generate input to LLM
Please note that you must not perform any blocking operations in this function
Returns:
a dictionary to be formatted by prompt template
"""
def do_action(self, prompt_response): def do_action(self, prompt_response):
return prompt_response return prompt_response
@ -116,8 +127,8 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response speak_to_user = prompt_define_response
return speak_to_user return speak_to_user
def __call_base(self): async def __call_base(self):
input_values = self.generate_input_values() input_values = await self.generate_input_values()
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(self.current_user_input)
@ -159,7 +170,7 @@ class BaseChat(ABC):
async def stream_call(self): async def stream_call(self):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = self.__call_base() payload = await self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
@ -190,7 +201,7 @@ class BaseChat(ABC):
self.memory.append(self.current_message) self.memory.append(self.current_message)
async def nostream_call(self): async def nostream_call(self):
payload = self.__call_base() payload = await self.__call_base()
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
@ -216,14 +227,24 @@ class BaseChat(ABC):
) )
) )
### run ### run
result = self.do_action(prompt_define_response) # result = self.do_action(prompt_define_response)
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)
### llm speaker ### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response) speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response( # view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result # speak_to_user, result
# )
view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
) )
view_message = view_message.replace("\n", "\\n") view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
except Exception as e: except Exception as e:

View File

@ -51,7 +51,7 @@ class ChatAgent(BaseChat):
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
def generate_input_values(self): async def generate_input_values(self) -> Dict[str, str]:
input_values = { input_values = {
"user_goal": self.current_user_input, "user_goal": self.current_user_input,
"expand_constraints": self.__list_to_prompt_str( "expand_constraints": self.__list_to_prompt_str(

View File

@ -1,11 +1,6 @@
import json import json
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple): class PluginAction(NamedTuple):

View File

@ -12,6 +12,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
) )
from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -52,7 +53,7 @@ class ChatDashboard(BaseChat):
data = f.read() data = f.read()
return json.loads(data) return json.loads(data)
def generate_input_values(self): async def generate_input_values(self) -> Dict:
try: try:
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
except ImportError: except ImportError:
@ -60,9 +61,16 @@ class ChatDashboard(BaseChat):
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try: try:
table_infos = client.get_similar_tables( table_infos = await blocking_func_to_async(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k self._executor,
client.get_similar_tables,
self.db_name,
self.current_user_input,
self.top_k,
) )
# table_infos = client.get_similar_tables(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
print("dashboard vector find tables:{}", table_infos) print("dashboard vector find tables:{}", table_infos)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))

View File

@ -62,7 +62,7 @@ class ChatExcel(BaseChat):
# ] # ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
"table_name": self.excel_reader.table_name, "table_name": self.excel_reader.table_name,

View File

@ -1,6 +1,5 @@
import json import json
import os from typing import Any, Dict
from typing import Any
from pilot.scene.base_message import ( from pilot.scene.base_message import (
HumanMessage, HumanMessage,
@ -13,6 +12,7 @@ from pilot.configs.config import Config
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -44,13 +44,15 @@ class ExcelLearning(BaseChat):
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
def generate_input_values(self): async def generate_input_values(self) -> Dict:
colunms, datas = self.excel_reader.get_sample_data() # colunms, datas = self.excel_reader.get_sample_data()
colunms, datas = await blocking_func_to_async(
self._executor, self.excel_reader.get_sample_data
)
copy_datas = datas.copy()
datas.insert(0, colunms) datas.insert(0, colunms)
input_values = { input_values = {
"data_example": json.dumps( "data_example": json.dumps(copy_datas, cls=DateTimeEncoder),
self.excel_reader.get_sample_data(), cls=DateTimeEncoder
),
} }
return input_values return input_values

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -38,7 +39,7 @@ class ChatWithDbAutoExecute(BaseChat):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200 self.top_k: int = 200
def generate_input_values(self): async def generate_input_values(self) -> Dict:
""" """
generate input values generate input values
""" """
@ -47,19 +48,27 @@ class ChatWithDbAutoExecute(BaseChat):
except ImportError: except ImportError:
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try: try:
table_infos = client.get_db_summary( # table_infos = client.get_db_summary(
dbname=self.db_name, # dbname=self.db_name,
query=self.current_user_input, # query=self.current_user_input,
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info()
if not table_infos: if not table_infos:
table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
table_infos = await blocking_func_to_async(
# table_infos = self.database.table_simple_info() self._executor, self.database.table_simple_info
)
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -38,7 +39,7 @@ class ChatWithDbQA(BaseChat):
else len(self.tables) else len(self.tables)
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
table_info = "" table_info = ""
dialect = "mysql" dialect = "mysql"
try: try:
@ -48,12 +49,22 @@ class ChatWithDbQA(BaseChat):
if self.db_name: if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try: try:
table_infos = client.get_db_summary( # table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k # dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
self.top_k,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
table_infos = await blocking_func_to_async(
self._executor, self.database.table_simple_info
)
# table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
dialect = self.database.dialect dialect = self.database.dialect

View File

@ -50,7 +50,7 @@ class ChatWithPlugin(BaseChat):
self.plugins_prompt_generator self.plugins_prompt_generator
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,
"constraints": self.__list_to_prompt_str( "constraints": self.__list_to_prompt_str(

View File

@ -1,5 +1,6 @@
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton from pilot.singleton import Singleton
from pilot.utils.tracer import root_tracer
class ChatFactory(metaclass=Singleton): class ChatFactory(metaclass=Singleton):
@ -20,7 +21,11 @@ class ChatFactory(metaclass=Singleton):
implementation = None implementation = None
for cls in chat_classes: for cls in chat_classes:
if cls.chat_scene == chat_mode: if cls.chat_scene == chat_mode:
implementation = cls(**kwargs) metadata = {"cls": str(cls), "params": kwargs}
with root_tracer.start_span(
"get_implementation_of_chat", metadata=metadata
):
implementation = cls(**kwargs)
if implementation == None: if implementation == None:
raise Exception(f"Invalid implementation name:{chat_mode}") raise Exception(f"Invalid implementation name:{chat_mode}")
return implementation return implementation

View File

@ -1,3 +1,4 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
@ -30,7 +31,7 @@ class InnerChatDBSummary(BaseChat):
self.db_input = db_select self.db_input = db_select
self.db_summary = db_summary self.db_summary = db_summary
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"db_input": self.db_input, "db_input": self.db_input,
"db_profile_summary": self.db_summary, "db_profile_summary": self.db_summary,

View File

@ -12,6 +12,7 @@ from pilot.configs.model_config import (
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -65,7 +66,7 @@ class ChatKnowledge(BaseChat):
self.prompt_template.template_is_strict = False self.prompt_template.template_is_strict = False
async def stream_call(self): async def stream_call(self):
input_values = self.generate_input_values() input_values = await self.generate_input_values()
# Source of knowledge file # Source of knowledge file
relations = input_values.get("relations") relations = input_values.get("relations")
last_output = None last_output = None
@ -85,12 +86,18 @@ class ChatKnowledge(BaseChat):
) )
yield last_output yield last_output
def generate_input_values(self): async def generate_input_values(self) -> Dict:
if self.space_context: if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"] self.prompt_template.template = self.space_context["prompt"]["template"]
docs = self.knowledge_embedding_client.similar_search( # docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k # self.current_user_input, self.top_k
# )
docs = await blocking_func_to_async(
self._executor,
self.knowledge_embedding_client.similar_search,
self.current_user_input,
self.top_k,
) )
if not docs: if not docs:
raise ValueError( raise ValueError(

View File

@ -21,7 +21,7 @@ class ChatNormal(BaseChat):
chat_param=chat_param, chat_param=chat_param,
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values

View File

@ -1,5 +1,8 @@
from typing import Callable, Awaitable, Any
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.component import BaseComponent, ComponentType, SystemApp
@ -24,3 +27,34 @@ class DefaultExecutorFactory(ExecutorFactory):
def create(self) -> Executor: def create(self) -> Executor:
return self._executor return self._executor
BlockingFunction = Callable[..., Any]


async def blocking_func_to_async(
    executor: Executor, func: BlockingFunction, *args, **kwargs
):
    """Run a potentially blocking function within an executor.

    Args:
        executor (Executor): The concurrent.futures.Executor to run the function within.
        func (BlockingFunction): The callable to run. It must be a synchronous
            (non-coroutine) function.
        *args (Any): Positional arguments to pass to the function.
        **kwargs (Any): Keyword arguments to pass to the function.

    Returns:
        Any: The result of the function's execution.

    Raises:
        ValueError: If the provided function 'func' is an asynchronous coroutine function.
    """
    if asyncio.iscoroutinefunction(func):
        raise ValueError(f"The function {func} is not blocking function")
    # This coroutine only ever executes on a running loop, so request that loop
    # explicitly instead of going through get_event_loop()'s policy lookup.
    loop = asyncio.get_running_loop()
    # run_in_executor() forwards positional arguments only, so bind both the
    # positional and keyword arguments before submitting the call.
    sync_function_noargs = partial(func, *args, **kwargs)
    return await loop.run_in_executor(executor, sync_function_noargs)

View File

@ -35,15 +35,14 @@ clone_repositories() {
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b-int4 git clone https://huggingface.co/THUDM/chatglm2-6b
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
rm -rf /root/DB-GPT/models/chatglm2-6b-int4/.git rm -rf /root/DB-GPT/models/chatglm2-6b/.git
} }
install_dbgpt_packages() { install_dbgpt_packages() {
conda activate dbgpt && cd /root/DB-GPT && pip install -e . && cp .env.template .env conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b-int4/' .env cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env
} }
clean_up() { clean_up() {

View File

@ -317,6 +317,8 @@ def core_requires():
# TODO move transformers to default # TODO move transformers to default
"transformers>=4.31.0", "transformers>=4.31.0",
"alembic==1.12.0", "alembic==1.12.0",
# for excel
"openpyxl",
] ]
@ -361,6 +363,8 @@ def quantization_requires():
) )
pkgs = [f"bitsandbytes @ {local_pkg}"] pkgs = [f"bitsandbytes @ {local_pkg}"]
print(pkgs) print(pkgs)
# For chatglm2-6b-int4
pkgs += ["cpm_kernels"]
setup_spec.extras["quantization"] = pkgs setup_spec.extras["quantization"] = pkgs