fix(core): Use thread pool to replace native threading.Thread

This commit is contained in:
FangYin Cheng 2023-09-25 17:33:51 +08:00
parent d49e0eec8b
commit 3d5e0f4028
7 changed files with 63 additions and 16 deletions

View File

@ -46,6 +46,7 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
class BaseComponent(LifeCycle, ABC):

View File

@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_MANAGE = None

View File

@ -4,7 +4,9 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.component import SystemApp
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
@ -76,7 +78,11 @@ class ConnectManager:
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
engine_args={
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
"pool_recycle": 3600,
"echo": True,
},
)
# default_mysql = MySQLConnect.from_uri(
# "mysql+pymysql://"
@ -208,13 +214,15 @@ class ConnectManager:
db_info.comment,
)
# async embedding
thread = threading.Thread(
target=self.db_summary_client.db_summary_embedding(
db_info.db_name, db_info.db_type
)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(
self.db_summary_client.db_summary_embedding,
db_info.db_name,
db_info.db_type,
)
thread.start()
except Exception as e:
raise ValueError("Add db connect info error" + str(e))
raise ValueError("Add db connect info error!" + str(e))
return True

View File

@ -4,6 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Type
from pilot.component import ComponentType, SystemApp
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@ -22,6 +23,9 @@ def initialize_components(
):
from pilot.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
system_app.register_instance(controller)
_initialize_embedding_model(

View File

@ -124,12 +124,13 @@ def knowledge_init(
def upload(filename: str):
try:
logger.info(f"Begin upload document: {filename} to {space.name}")
return client.document_upload(
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
except Exception as e:
if skip_wrong_doc:
logger.warn(f"Warning: {str(e)}")
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
else:
raise e
@ -144,5 +145,3 @@ def knowledge_init(
if not doc_ids:
logger.warn("Warning: no document to sync")
return
logger.info(f"Begin sync document: {doc_ids}")
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))

View File

@ -9,6 +9,9 @@ from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@ -227,10 +230,15 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start()
# thread = threading.Thread(
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
# )
# thread.start()
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Optional

from pilot.component import BaseComponent, ComponentType, SystemApp
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component that hands out a ``concurrent.futures.Executor``.

    Registered under the global default-executor component slot so that
    callers can fetch a shared executor via
    ``system_app.get_component(ComponentType.EXECUTOR_DEFAULT, ExecutorFactory)``.
    """

    # Component registry key for the default executor factory.
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Create executor"""
class DefaultExecutorFactory(ExecutorFactory):
    """Default :class:`ExecutorFactory` backed by one shared thread pool.

    A single ``ThreadPoolExecutor`` is created at construction time and
    returned from every :meth:`create` call, so all submitters share the
    same pool of worker threads.
    """

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        max_workers: Optional[int] = None,
    ):
        """Build the shared thread pool.

        Args:
            system_app: Owning application container; forwarded to the
                component base class. May be ``None`` when registered lazily.
            max_workers: Pool size. ``None`` lets ``ThreadPoolExecutor``
                pick its version-dependent default.

        Note: ``Optional[SystemApp]`` is used instead of ``SystemApp | None``
        because PEP 604 union syntax in a runtime-evaluated annotation raises
        ``TypeError`` on Python < 3.10 (no ``from __future__ import
        annotations`` in this module).
        """
        super().__init__(system_app)
        # Worker threads are named after the component for easier debugging.
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp) -> None:
        """No additional wiring is required for this component."""
        pass

    def create(self) -> Executor:
        """Return the shared thread-pool executor."""
        return self._executor