mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
fix(core): Use thread pool to replace native threading.Thread (#628)
Close #627
This commit is contained in:
commit
0e9662a13f
@ -46,6 +46,7 @@ class ComponentType(str, Enum):
|
||||
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
|
||||
|
||||
|
||||
class BaseComponent(LifeCycle, ABC):
|
||||
|
@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
|
||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
|
||||
|
||||
self.LOCAL_DB_MANAGE = None
|
||||
|
||||
|
@ -4,7 +4,9 @@ import asyncio
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.component import SystemApp
|
||||
from pilot.component import SystemApp, ComponentType
|
||||
from pilot.utils.executor_utils import ExecutorFactory
|
||||
|
||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||
from pilot.connections.base import BaseConnect
|
||||
|
||||
@ -76,7 +78,11 @@ class ConnectManager:
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT),
|
||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
||||
engine_args={
|
||||
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
|
||||
"pool_recycle": 3600,
|
||||
"echo": True,
|
||||
},
|
||||
)
|
||||
# default_mysql = MySQLConnect.from_uri(
|
||||
# "mysql+pymysql://"
|
||||
@ -208,13 +214,15 @@ class ConnectManager:
|
||||
db_info.comment,
|
||||
)
|
||||
# async embedding
|
||||
thread = threading.Thread(
|
||||
target=self.db_summary_client.db_summary_embedding(
|
||||
db_info.db_name, db_info.db_type
|
||||
)
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
executor.submit(
|
||||
self.db_summary_client.db_summary_embedding,
|
||||
db_info.db_name,
|
||||
db_info.db_type,
|
||||
)
|
||||
thread.start()
|
||||
except Exception as e:
|
||||
raise ValueError("Add db connect info error!" + str(e))
|
||||
raise ValueError("Add db connect info error!" + str(e))
|
||||
|
||||
return True
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
|
||||
from pilot.component import ComponentType, SystemApp
|
||||
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from pilot.server.base import WebWerverParameters
|
||||
|
||||
@ -22,6 +23,9 @@ def initialize_components(
|
||||
):
|
||||
from pilot.model.cluster.controller.controller import controller
|
||||
|
||||
# Register global default executor factory first
|
||||
system_app.register(DefaultExecutorFactory)
|
||||
|
||||
system_app.register_instance(controller)
|
||||
|
||||
_initialize_embedding_model(
|
||||
|
@ -124,12 +124,13 @@ def knowledge_init(
|
||||
def upload(filename: str):
|
||||
try:
|
||||
logger.info(f"Begin upload document: {filename} to {space.name}")
|
||||
return client.document_upload(
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
|
||||
except Exception as e:
|
||||
if skip_wrong_doc:
|
||||
logger.warn(f"Warning: {str(e)}")
|
||||
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
@ -144,5 +145,3 @@ def knowledge_init(
|
||||
if not doc_ids:
|
||||
logger.warn("Warning: no document to sync")
|
||||
return
|
||||
logger.info(f"Begin sync document: {doc_ids}")
|
||||
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
|
||||
|
@ -9,6 +9,9 @@ from pilot.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from pilot.component import ComponentType
|
||||
from pilot.utils.executor_utils import ExecutorFactory
|
||||
|
||||
from pilot.logs import logger
|
||||
from pilot.server.knowledge.chunk_db import (
|
||||
DocumentChunkEntity,
|
||||
@ -227,10 +230,15 @@ class KnowledgeService:
|
||||
doc.gmt_modified = datetime.now()
|
||||
knowledge_document_dao.update_knowledge_document(doc)
|
||||
# async doc embeddings
|
||||
thread = threading.Thread(
|
||||
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||
)
|
||||
thread.start()
|
||||
# thread = threading.Thread(
|
||||
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||
# )
|
||||
# thread.start()
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
|
26
pilot/utils/executor_utils.py
Normal file
26
pilot/utils/executor_utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
|
||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
|
||||
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component whose implementations supply an ``Executor``.

    The component is registered under the name given by
    ``ComponentType.EXECUTOR_DEFAULT``.
    """

    # Component name used when this factory is registered or looked up.
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Return the executor provided by this factory.

        Concrete subclasses must override this method.
        """
|
||||
|
||||
|
||||
class DefaultExecutorFactory(ExecutorFactory):
    """Executor factory that wraps a single ``ThreadPoolExecutor``.

    The pool is built once in the constructor, and ``create`` returns that
    same pool on every call.
    """

    def __init__(self, system_app: SystemApp | None = None, max_workers=None):
        """Build the thread pool held by this factory.

        Args:
            system_app: Forwarded unchanged to the base component constructor.
            max_workers: Forwarded to ``ThreadPoolExecutor``; ``None`` leaves
                the pool size to the standard-library default.
        """
        super().__init__(system_app)
        # Worker threads are named after the component (``self.name``).
        pool_options = {
            "max_workers": max_workers,
            "thread_name_prefix": self.name,
        }
        self._executor = ThreadPoolExecutor(**pool_options)

    def init_app(self, system_app: SystemApp):
        """Do nothing; the pool is already built in ``__init__``."""
        return None

    def create(self) -> Executor:
        """Return the thread pool created in the constructor."""
        shared_pool = self._executor
        return shared_pool
|
6
setup.py
6
setup.py
@ -8,7 +8,6 @@ from enum import Enum
|
||||
import urllib.request
|
||||
from urllib.parse import urlparse, quote
|
||||
import re
|
||||
from pip._internal.utils.appdirs import user_cache_dir
|
||||
import shutil
|
||||
from setuptools import find_packages
|
||||
|
||||
@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
|
||||
safe_url, parsed_url = encode_url(package_url)
|
||||
if BUILD_NO_CACHE:
|
||||
return safe_url
|
||||
|
||||
from pip._internal.utils.appdirs import user_cache_dir
|
||||
|
||||
filename = os.path.basename(parsed_url)
|
||||
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
@ -279,7 +281,7 @@ def core_requires():
|
||||
"importlib-resources==5.12.0",
|
||||
"psutil==5.9.4",
|
||||
"python-dotenv==1.0.0",
|
||||
"colorama",
|
||||
"colorama==0.4.6",
|
||||
"prettytable",
|
||||
"cachetools",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user