feat(core): Enhance server request processing performance

This commit is contained in:
FangYin Cheng 2023-10-24 11:52:57 +08:00
parent 185d4366b9
commit 48cd2d6a4a
17 changed files with 178 additions and 53 deletions

View File

@ -1,5 +1,6 @@
from .base import MemoryStoreType from .base import MemoryStoreType
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
CFG = Config() CFG = Config()
@ -18,7 +19,15 @@ class ChatHistory:
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
def get_store_instance(self, chat_session_id): def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
"""New store instance for store chat histories
Args:
chat_session_id (str): conversation session id
Returns:
BaseChatHistoryMemory: Store instance
"""
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)( return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
chat_session_id chat_session_id
) )

View File

@ -19,6 +19,7 @@ from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
import tempfile import tempfile
from concurrent.futures import Executor
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
@ -46,6 +47,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel from pilot.model.base import FlatSupportedModel
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
@ -129,6 +131,13 @@ def get_worker_manager() -> WorkerManager:
return worker_manager return worker_manager
def get_executor() -> Executor:
    """Get the global default executor.

    Looks up the ``ExecutorFactory`` component registered under
    ``ComponentType.EXECUTOR_DEFAULT`` on ``CFG.SYSTEM_APP`` and returns
    whatever its ``create()`` method yields.

    Returns:
        Executor: The executor produced by the default executor factory.
    """
    executor_factory = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    )
    return executor_factory.create()
@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) @router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list(): async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -158,6 +167,7 @@ async def async_db_summary_embedding(db_name, db_type):
@router.post("/v1/chat/db/test/connect", response_model=Result[bool]) @router.post("/v1/chat/db/test/connect", response_model=Result[bool])
async def test_connect(db_config: DBConfig = Body()): async def test_connect(db_config: DBConfig = Body()):
try: try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config) CFG.LOCAL_DB_MANAGE.test_connect(db_config)
return Result.succ(True) return Result.succ(True)
except Exception as e: except Exception as e:
@ -166,6 +176,7 @@ async def test_connect(db_config: DBConfig = Body()):
@router.post("/v1/chat/db/summary", response_model=Result[bool])
async def db_summary(db_name: str, db_type: str):
    """Run the database summary embedding and report success.

    Args:
        db_name (str): Forwarded unchanged to ``async_db_summary_embedding``.
        db_type (str): Forwarded unchanged to ``async_db_summary_embedding``.

    Returns:
        ``Result.succ(True)`` after the embedding coroutine has completed.
    """
    # ``async_db_summary_embedding`` is declared with ``async def``; calling it
    # without ``await`` only builds a coroutine object that is never executed,
    # so the endpoint would answer "success" without doing any work.
    # TODO Change the synchronous call to the asynchronous call
    await async_db_summary_embedding(db_name, db_type)
    return Result.succ(True)
@ -185,6 +196,7 @@ async def db_support_types():
async def dialogue_list(user_id: str = None): async def dialogue_list(user_id: str = None):
dialogues: List = [] dialogues: List = []
chat_history_service = ChatHistory() chat_history_service = ChatHistory()
# TODO Change the synchronous call to the asynchronous call
datas = chat_history_service.get_store_cls().conv_list(user_id) datas = chat_history_service.get_store_cls().conv_list(user_id)
for item in datas: for item in datas:
conv_uid = item.get("conv_uid") conv_uid = item.get("conv_uid")
@ -285,7 +297,7 @@ async def params_load(
select_param=doc_file.filename, select_param=doc_file.filename,
model_name=model_name, model_name=model_name,
) )
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare() resp = await chat.prepare()
### refresh messages ### refresh messages
@ -299,6 +311,7 @@ async def params_load(
async def dialogue_delete(con_uid: str): async def dialogue_delete(con_uid: str):
history_fac = ChatHistory() history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid) history_mem = history_fac.get_store_instance(con_uid)
# TODO Change the synchronous call to the asynchronous call
history_mem.delete() history_mem.delete()
return Result.succ(None) return Result.succ(None)
@ -324,10 +337,11 @@ def get_hist_messages(conv_uid: str):
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo]) @router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str): async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}") print(f"dialogue_history_messages:{con_uid}")
# TODO Change the synchronous call to the asynchronous call
return Result.succ(get_hist_messages(con_uid)) return Result.succ(get_hist_messages(con_uid))
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
logger.info(f"get_chat_instance:{dialogue}") logger.info(f"get_chat_instance:{dialogue}")
if not dialogue.chat_mode: if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value() dialogue.chat_mode = ChatScene.ChatNormal.value()
@ -346,8 +360,14 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
"select_param": dialogue.select_param, "select_param": dialogue.select_param,
"model_name": dialogue.model_name, "model_name": dialogue.model_name,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation( # chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param} # dialogue.chat_mode, **{"chat_param": chat_param}
# )
chat: BaseChat = await blocking_func_to_async(
get_executor(),
CHAT_FACTORY.get_implementation,
dialogue.chat_mode,
**{"chat_param": chat_param},
) )
return chat return chat
@ -357,7 +377,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL # dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}") logger.info(f"chat_prepare:{dialogue}")
## check conv_uid ## check conv_uid
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = await get_chat_instance(dialogue)
if len(chat.history_message) > 0: if len(chat.history_message) > 0:
return Result.succ(None) return Result.succ(None)
resp = await chat.prepare() resp = await chat.prepare()
@ -369,7 +389,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
print( print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
) )
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = await get_chat_instance(dialogue)
# background_tasks = BackgroundTasks() # background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore) # background_tasks.add_task(release_model_semaphore)
headers = { headers = {

View File

@ -12,6 +12,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pydantic import Extra from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@ -80,6 +81,10 @@ class BaseChat(ABC):
self.current_message.param_type = self.chat_mode.param_types()[0] self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = chat_param["select_param"] self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -92,8 +97,14 @@ class BaseChat(ABC):
raise NotImplementedError("Not supported for this chat type.") raise NotImplementedError("Not supported for this chat type.")
@abstractmethod @abstractmethod
def generate_input_values(self): async def generate_input_values(self) -> Dict:
pass """Generate input to LLM
Please note that you must not perform any blocking operations in this function
Returns:
a dictionary to be formatted by prompt template
"""
def do_action(self, prompt_response): def do_action(self, prompt_response):
return prompt_response return prompt_response
@ -116,8 +127,8 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response speak_to_user = prompt_define_response
return speak_to_user return speak_to_user
def __call_base(self): async def __call_base(self):
input_values = self.generate_input_values() input_values = await self.generate_input_values()
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(self.current_user_input)
@ -159,7 +170,7 @@ class BaseChat(ABC):
async def stream_call(self): async def stream_call(self):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = self.__call_base() payload = await self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
@ -190,7 +201,7 @@ class BaseChat(ABC):
self.memory.append(self.current_message) self.memory.append(self.current_message)
async def nostream_call(self): async def nostream_call(self):
payload = self.__call_base() payload = await self.__call_base()
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
@ -216,14 +227,24 @@ class BaseChat(ABC):
) )
) )
### run ### run
result = self.do_action(prompt_define_response) # result = self.do_action(prompt_define_response)
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)
### llm speaker ### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response) speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response( # view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result # speak_to_user, result
# )
view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
) )
view_message = view_message.replace("\n", "\\n") view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
except Exception as e: except Exception as e:

View File

@ -51,7 +51,7 @@ class ChatAgent(BaseChat):
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
def generate_input_values(self): async def generate_input_values(self) -> Dict[str, str]:
input_values = { input_values = {
"user_goal": self.current_user_input, "user_goal": self.current_user_input,
"expand_constraints": self.__list_to_prompt_str( "expand_constraints": self.__list_to_prompt_str(

View File

@ -12,6 +12,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
) )
from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -52,7 +53,7 @@ class ChatDashboard(BaseChat):
data = f.read() data = f.read()
return json.loads(data) return json.loads(data)
def generate_input_values(self): async def generate_input_values(self) -> Dict:
try: try:
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
except ImportError: except ImportError:
@ -60,9 +61,16 @@ class ChatDashboard(BaseChat):
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try: try:
table_infos = client.get_similar_tables( table_infos = await blocking_func_to_async(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k self._executor,
client.get_similar_tables,
self.db_name,
self.current_user_input,
self.top_k,
) )
# table_infos = client.get_similar_tables(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
print("dashboard vector find tables:{}", table_infos) print("dashboard vector find tables:{}", table_infos)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))

View File

@ -62,7 +62,7 @@ class ChatExcel(BaseChat):
# ] # ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
"table_name": self.excel_reader.table_name, "table_name": self.excel_reader.table_name,

View File

@ -1,6 +1,5 @@
import json import json
import os from typing import Any, Dict
from typing import Any
from pilot.scene.base_message import ( from pilot.scene.base_message import (
HumanMessage, HumanMessage,
@ -13,6 +12,7 @@ from pilot.configs.config import Config
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -44,13 +44,15 @@ class ExcelLearning(BaseChat):
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
async def generate_input_values(self) -> Dict:
    """Build the prompt input from a sample of the Excel data.

    The sample is fetched with ``blocking_func_to_async`` on
    ``self._executor`` so that ``excel_reader.get_sample_data`` does not run
    on the event loop thread.

    Returns:
        Dict: ``{"data_example": <JSON string>}``, where the JSON encodes the
        column names as the first entry followed by the sample rows.
    """
    colunms, datas = await blocking_func_to_async(
        self._executor, self.excel_reader.get_sample_data
    )
    # Put the header row in front of the sample rows. The previous code
    # copied ``datas`` *before* inserting the columns and then serialised the
    # copy, so the column names never reached the prompt. Building a new list
    # also avoids mutating the list returned by the reader.
    sample_with_header = [colunms, *datas]
    input_values = {
        "data_example": json.dumps(sample_with_header, cls=DateTimeEncoder),
    }
    return input_values

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -38,7 +39,7 @@ class ChatWithDbAutoExecute(BaseChat):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200 self.top_k: int = 200
def generate_input_values(self): async def generate_input_values(self) -> Dict:
""" """
generate input values generate input values
""" """
@ -47,19 +48,27 @@ class ChatWithDbAutoExecute(BaseChat):
except ImportError: except ImportError:
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try: try:
table_infos = client.get_db_summary( # table_infos = client.get_db_summary(
dbname=self.db_name, # dbname=self.db_name,
query=self.current_user_input, # query=self.current_user_input,
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info()
if not table_infos: if not table_infos:
table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
table_infos = await blocking_func_to_async(
# table_infos = self.database.table_simple_info() self._executor, self.database.table_simple_info
)
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -38,7 +39,7 @@ class ChatWithDbQA(BaseChat):
else len(self.tables) else len(self.tables)
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
table_info = "" table_info = ""
dialect = "mysql" dialect = "mysql"
try: try:
@ -48,12 +49,22 @@ class ChatWithDbQA(BaseChat):
if self.db_name: if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try: try:
table_infos = client.get_db_summary( # table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k # dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
self.top_k,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
table_infos = await blocking_func_to_async(
self._executor, self.database.table_simple_info
)
# table_infos = self.database.table_simple_info() # table_infos = self.database.table_simple_info()
dialect = self.database.dialect dialect = self.database.dialect

View File

@ -50,7 +50,7 @@ class ChatWithPlugin(BaseChat):
self.plugins_prompt_generator self.plugins_prompt_generator
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,
"constraints": self.__list_to_prompt_str( "constraints": self.__list_to_prompt_str(

View File

@ -1,3 +1,4 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
@ -30,7 +31,7 @@ class InnerChatDBSummary(BaseChat):
self.db_input = db_select self.db_input = db_select
self.db_summary = db_summary self.db_summary = db_summary
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"db_input": self.db_input, "db_input": self.db_input,
"db_profile_summary": self.db_summary, "db_profile_summary": self.db_summary,

View File

@ -12,6 +12,7 @@ from pilot.configs.model_config import (
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async
CFG = Config() CFG = Config()
@ -65,7 +66,7 @@ class ChatKnowledge(BaseChat):
self.prompt_template.template_is_strict = False self.prompt_template.template_is_strict = False
async def stream_call(self): async def stream_call(self):
input_values = self.generate_input_values() input_values = await self.generate_input_values()
# Source of knowledge file # Source of knowledge file
relations = input_values.get("relations") relations = input_values.get("relations")
last_output = None last_output = None
@ -85,12 +86,18 @@ class ChatKnowledge(BaseChat):
) )
yield last_output yield last_output
def generate_input_values(self): async def generate_input_values(self) -> Dict:
if self.space_context: if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"] self.prompt_template.template = self.space_context["prompt"]["template"]
docs = self.knowledge_embedding_client.similar_search( # docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k # self.current_user_input, self.top_k
# )
docs = await blocking_func_to_async(
self._executor,
self.knowledge_embedding_client.similar_search,
self.current_user_input,
self.top_k,
) )
if not docs: if not docs:
raise ValueError( raise ValueError(

View File

@ -21,7 +21,7 @@ class ChatNormal(BaseChat):
chat_param=chat_param, chat_param=chat_param,
) )
def generate_input_values(self): async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values

View File

@ -1,5 +1,8 @@
from typing import Callable, Awaitable, Any
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.component import BaseComponent, ComponentType, SystemApp
@ -24,3 +27,34 @@ class DefaultExecutorFactory(ExecutorFactory):
def create(self) -> Executor: def create(self) -> Executor:
return self._executor return self._executor
# Any synchronous callable; its arguments and return value are unconstrained.
BlockingFunction = Callable[..., Any]


async def blocking_func_to_async(
    executor: Executor, func: BlockingFunction, *args, **kwargs
):
    """Run a potentially blocking function within an executor.

    Args:
        executor (Executor): The concurrent.futures.Executor to run the
            function within.
        func (BlockingFunction): The callable to run. It must be a synchronous
            function (not a coroutine function) and may accept any number and
            type of arguments.
        *args (Any): Positional arguments to pass to the function.
        **kwargs (Any): Keyword arguments to pass to the function.

    Returns:
        Any: The result of the function's execution.

    Raises:
        ValueError: If the provided function 'func' is an asynchronous
            coroutine function.
    """
    if asyncio.iscoroutinefunction(func):
        raise ValueError(f"The function {func} is not blocking function")
    # This body only ever executes inside a running coroutine, so ask for the
    # running loop explicitly instead of going through get_event_loop().
    loop = asyncio.get_running_loop()
    # run_in_executor forwards positional arguments only; bind everything
    # (including keyword arguments) up front with partial.
    sync_function_noargs = partial(func, *args, **kwargs)
    return await loop.run_in_executor(executor, sync_function_noargs)

View File

@ -35,15 +35,14 @@ clone_repositories() {
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b-int4 git clone https://huggingface.co/THUDM/chatglm2-6b
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
rm -rf /root/DB-GPT/models/chatglm2-6b-int4/.git rm -rf /root/DB-GPT/models/chatglm2-6b/.git
} }
install_dbgpt_packages() { install_dbgpt_packages() {
conda activate dbgpt && cd /root/DB-GPT && pip install -e . && cp .env.template .env conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b-int4/' .env cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env
} }
clean_up() { clean_up() {

View File

@ -317,6 +317,8 @@ def core_requires():
# TODO move transformers to default # TODO move transformers to default
"transformers>=4.31.0", "transformers>=4.31.0",
"alembic==1.12.0", "alembic==1.12.0",
# for excel
"openpyxl",
] ]
@ -361,6 +363,8 @@ def quantization_requires():
) )
pkgs = [f"bitsandbytes @ {local_pkg}"] pkgs = [f"bitsandbytes @ {local_pkg}"]
print(pkgs) print(pkgs)
# For chatglm2-6b-int4
pkgs += ["cpm_kernels"]
setup_spec.extras["quantization"] = pkgs setup_spec.extras["quantization"] = pkgs