fix(core): Use thread pool to replace native threading.Thread (#628)

Close #627
Aries-ckt 2023-09-25 17:45:39 +08:00 committed by GitHub
commit 0e9662a13f
8 changed files with 67 additions and 18 deletions


@@ -46,6 +46,7 @@ class ComponentType(str, Enum):
     WORKER_MANAGER = "dbgpt_worker_manager"
     WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
     MODEL_CONTROLLER = "dbgpt_model_controller"
+    EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"

 class BaseComponent(LifeCycle, ABC):


@@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
         self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
+        self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
         self.LOCAL_DB_MANAGE = None
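With this change the local MySQL connection pool size is read from the environment instead of being hard-coded to 10. A minimal sketch of how the new setting flows into SQLAlchemy; the create_engine call and connection URI below are illustrative only (the project actually builds the URI in ConnectManager and forwards the value through engine_args):

    import os
    from sqlalchemy import create_engine

    # Mirrors Config: fall back to 10 when LOCAL_DB_POOL_SIZE is unset.
    pool_size = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))

    engine = create_engine(
        "mysql+pymysql://root:aa123456@127.0.0.1:3306/mysql",  # illustrative URI
        pool_size=pool_size,    # was hard-coded to 10 before this commit
        pool_recycle=3600,
        echo=True,
    )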


@@ -4,7 +4,9 @@ import asyncio
 from pilot.configs.config import Config
 from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
 from pilot.common.schema import DBType
-from pilot.component import SystemApp
+from pilot.component import SystemApp, ComponentType
+from pilot.utils.executor_utils import ExecutorFactory
 from pilot.connections.rdbms.conn_mysql import MySQLConnect
 from pilot.connections.base import BaseConnect
@@ -76,7 +78,11 @@ class ConnectManager:
                 + CFG.LOCAL_DB_HOST
                 + ":"
                 + str(CFG.LOCAL_DB_PORT),
-                engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+                engine_args={
+                    "pool_size": CFG.LOCAL_DB_POOL_SIZE,
+                    "pool_recycle": 3600,
+                    "echo": True,
+                },
             )
             # default_mysql = MySQLConnect.from_uri(
             #     "mysql+pymysql://"
@@ -208,13 +214,15 @@ class ConnectManager:
                 db_info.comment,
             )
             # async embedding
-            thread = threading.Thread(
-                target=self.db_summary_client.db_summary_embedding(
-                    db_info.db_name, db_info.db_type
-                )
-            )
-            thread.start()
+            executor = CFG.SYSTEM_APP.get_component(
+                ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+            ).create()
+            executor.submit(
+                self.db_summary_client.db_summary_embedding,
+                db_info.db_name,
+                db_info.db_type,
+            )
         except Exception as e:
-            raise ValueError("Add db connect info error" + str(e))
+            raise ValueError("Add db connect info error!" + str(e))
         return True
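This hunk also fixes a latent bug in the code it replaces: threading.Thread(target=fn(args)) calls fn immediately on the caller's thread and hands Thread its return value instead of a callable, so the "async" embedding was effectively synchronous. executor.submit(fn, *args) passes the callable and its arguments separately. A minimal, self-contained sketch of the difference (embed is a hypothetical stand-in for db_summary_embedding):

    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor

    def embed(name: str) -> None:  # hypothetical stand-in for db_summary_embedding
        time.sleep(0.1)
        print(f"embedded {name}")

    # Old pattern: embed("db1") runs right here, blocking the caller;
    # Thread receives its return value (None) as target and does nothing.
    thread = threading.Thread(target=embed("db1"))
    thread.start()

    # New pattern: the callable and its arguments go to the pool, which
    # runs them on a worker thread and returns a Future.
    with ThreadPoolExecutor(max_workers=2) as pool:
        future = pool.submit(embed, "db1")
        future.result()  # block only if/when the result is actually needed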


@@ -4,6 +4,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Type
 from pilot.component import ComponentType, SystemApp
+from pilot.utils.executor_utils import DefaultExecutorFactory
 from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 from pilot.server.base import WebWerverParameters

@@ -22,6 +23,9 @@ def initialize_components(
 ):
     from pilot.model.cluster.controller.controller import controller

+    # Register global default executor factory first
+    system_app.register(DefaultExecutorFactory)
+
     system_app.register_instance(controller)

     _initialize_embedding_model(


@@ -124,12 +124,13 @@ def knowledge_init(
     def upload(filename: str):
         try:
             logger.info(f"Begin upload document: {filename} to {space.name}")
-            return client.document_upload(
+            doc_id = client.document_upload(
                 space.name, filename, KnowledgeType.DOCUMENT.value, filename
             )
+            client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
         except Exception as e:
             if skip_wrong_doc:
-                logger.warn(f"Warning: {str(e)}")
+                logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
             else:
                 raise e

@@ -144,5 +145,3 @@ def knowledge_init(
     if not doc_ids:
         logger.warn("Warning: no document to sync")
         return
-    logger.info(f"Begin sync document: {doc_ids}")
-    client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))


@@ -9,6 +9,9 @@ from pilot.configs.model_config import (
     EMBEDDING_MODEL_CONFIG,
     KNOWLEDGE_UPLOAD_ROOT_PATH,
 )
+from pilot.component import ComponentType
+from pilot.utils.executor_utils import ExecutorFactory
+
 from pilot.logs import logger
 from pilot.server.knowledge.chunk_db import (
     DocumentChunkEntity,
@@ -227,10 +230,15 @@ class KnowledgeService:
         doc.gmt_modified = datetime.now()
         knowledge_document_dao.update_knowledge_document(doc)
         # async doc embeddings
-        thread = threading.Thread(
-            target=self.async_doc_embedding, args=(client, chunk_docs, doc)
-        )
-        thread.start()
+        # thread = threading.Thread(
+        #     target=self.async_doc_embedding, args=(client, chunk_docs, doc)
+        # )
+        # thread.start()
+        executor = CFG.SYSTEM_APP.get_component(
+            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+        ).create()
+        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
         # save chunk details
         chunk_entities = [


@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+from concurrent.futures import Executor, ThreadPoolExecutor
+
+from pilot.component import BaseComponent, ComponentType, SystemApp
+
+
+class ExecutorFactory(BaseComponent, ABC):
+    name = ComponentType.EXECUTOR_DEFAULT.value
+
+    @abstractmethod
+    def create(self) -> "Executor":
+        """Create executor"""
+
+
+class DefaultExecutorFactory(ExecutorFactory):
+    def __init__(self, system_app: SystemApp | None = None, max_workers=None):
+        super().__init__(system_app)
+        self._executor = ThreadPoolExecutor(
+            max_workers=max_workers, thread_name_prefix=self.name
+        )
+
+    def init_app(self, system_app: SystemApp):
+        pass
+
+    def create(self) -> Executor:
+        return self._executor
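The factory is registered once at startup (see the component_configs.py hunk above); other components then look up the shared pool instead of spawning their own threads. A sketch of the lookup-and-submit pattern the updated ConnectManager and KnowledgeService use; do_embedding is a hypothetical task, and constructing SystemApp standalone is an assumption made for the example:

    from pilot.component import ComponentType, SystemApp
    from pilot.utils.executor_utils import DefaultExecutorFactory, ExecutorFactory

    def do_embedding(db_name: str, db_type: str) -> None:
        ...  # hypothetical blocking work, e.g. a summary embedding

    system_app = SystemApp()                     # normally created by the web server
    system_app.register(DefaultExecutorFactory)  # done once in initialize_components()

    executor = system_app.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    ).create()
    future = executor.submit(do_embedding, "db1", "mysql")  # concurrent.futures.Future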


@@ -8,7 +8,6 @@ from enum import Enum
 import urllib.request
 from urllib.parse import urlparse, quote
 import re
-from pip._internal.utils.appdirs import user_cache_dir
 import shutil
 from setuptools import find_packages
@@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False):
     safe_url, parsed_url = encode_url(package_url)
     if BUILD_NO_CACHE:
         return safe_url
+
+    from pip._internal.utils.appdirs import user_cache_dir
+
     filename = os.path.basename(parsed_url)
     cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
     os.makedirs(cache_dir, exist_ok=True)
@@ -279,7 +281,7 @@ def core_requires():
         "importlib-resources==5.12.0",
         "psutil==5.9.4",
         "python-dotenv==1.0.0",
"colorama", "colorama==0.4.10",
"prettytable", "prettytable",
"cachetools", "cachetools",
] ]