fix(core): Use thread pool to replace native threading.Thread (#628)

Close #627
This commit is contained in:
Aries-ckt 2023-09-25 17:45:39 +08:00 committed by GitHub
commit 0e9662a13f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 18 deletions

View File

@ -46,6 +46,7 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
class BaseComponent(LifeCycle, ABC):

View File

@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_MANAGE = None

View File

@ -4,7 +4,9 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.component import SystemApp
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
@ -76,7 +78,11 @@ class ConnectManager:
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
engine_args={
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
"pool_recycle": 3600,
"echo": True,
},
)
# default_mysql = MySQLConnect.from_uri(
# "mysql+pymysql://"
@ -208,13 +214,15 @@ class ConnectManager:
db_info.comment,
)
# async embedding
thread = threading.Thread(
target=self.db_summary_client.db_summary_embedding(
db_info.db_name, db_info.db_type
)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(
self.db_summary_client.db_summary_embedding,
db_info.db_name,
db_info.db_type,
)
thread.start()
except Exception as e:
raise ValueError("Add db connect info error" + str(e))
raise ValueError("Add db connect info error!" + str(e))
return True

View File

@ -4,6 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Type
from pilot.component import ComponentType, SystemApp
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@ -22,6 +23,9 @@ def initialize_components(
):
from pilot.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
system_app.register_instance(controller)
_initialize_embedding_model(

View File

@ -124,12 +124,13 @@ def knowledge_init(
def upload(filename: str):
try:
logger.info(f"Begin upload document: {filename} to {space.name}")
return client.document_upload(
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
except Exception as e:
if skip_wrong_doc:
logger.warn(f"Warning: {str(e)}")
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
else:
raise e
@ -144,5 +145,3 @@ def knowledge_init(
if not doc_ids:
logger.warn("Warning: no document to sync")
return
logger.info(f"Begin sync document: {doc_ids}")
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))

View File

@ -9,6 +9,9 @@ from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@ -227,10 +230,15 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start()
# thread = threading.Thread(
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
# )
# thread.start()
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Optional

from pilot.component import BaseComponent, ComponentType, SystemApp
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component that supplies a ``concurrent.futures.Executor``.

    Registered under ``ComponentType.EXECUTOR_DEFAULT`` so other parts of the
    system can obtain a shared executor via
    ``system_app.get_component(ComponentType.EXECUTOR_DEFAULT, ExecutorFactory).create()``.
    """

    # Registry key used when this component is looked up in the SystemApp.
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Create (or return) the executor used for background tasks."""
class DefaultExecutorFactory(ExecutorFactory):
    """Default :class:`ExecutorFactory` backed by one shared ThreadPoolExecutor.

    The pool is built eagerly in ``__init__`` and the same instance is handed
    out by every ``create()`` call, so all callers share a single thread pool.
    """

    # NOTE: the annotation uses typing.Optional instead of ``SystemApp | None``
    # because ``X | None`` in a signature is evaluated at class-definition time
    # and raises TypeError on Python < 3.10 (this module has no
    # ``from __future__ import annotations``).
    def __init__(self, system_app: Optional[SystemApp] = None, max_workers=None):
        """Create the shared thread pool.

        Args:
            system_app: Owning SystemApp, forwarded to ``BaseComponent``.
            max_workers: Pool size; ``None`` lets ThreadPoolExecutor choose
                its own default.
        """
        super().__init__(system_app)
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp):
        """No extra wiring needed: the pool is already built in ``__init__``."""
        pass

    def create(self) -> Executor:
        """Return the shared ThreadPoolExecutor instance."""
        return self._executor

View File

@ -8,7 +8,6 @@ from enum import Enum
import urllib.request
from urllib.parse import urlparse, quote
import re
from pip._internal.utils.appdirs import user_cache_dir
import shutil
from setuptools import find_packages
@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
safe_url, parsed_url = encode_url(package_url)
if BUILD_NO_CACHE:
return safe_url
from pip._internal.utils.appdirs import user_cache_dir
filename = os.path.basename(parsed_url)
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
os.makedirs(cache_dir, exist_ok=True)
@ -279,7 +281,7 @@ def core_requires():
"importlib-resources==5.12.0",
"psutil==5.9.4",
"python-dotenv==1.0.0",
"colorama",
"colorama==0.4.6",
"prettytable",
"cachetools",
]