mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 02:46:40 +00:00
fix(core): Use thread pool to replace native threading.Thread (#628)
Close #627
This commit is contained in:
commit
0e9662a13f
@ -46,6 +46,7 @@ class ComponentType(str, Enum):
|
|||||||
WORKER_MANAGER = "dbgpt_worker_manager"
|
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||||
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
||||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||||
|
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
|
||||||
|
|
||||||
|
|
||||||
class BaseComponent(LifeCycle, ABC):
|
class BaseComponent(LifeCycle, ABC):
|
||||||
|
@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
|
|||||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||||
|
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
|
||||||
|
|
||||||
self.LOCAL_DB_MANAGE = None
|
self.LOCAL_DB_MANAGE = None
|
||||||
|
|
||||||
|
@ -4,7 +4,9 @@ import asyncio
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
||||||
from pilot.common.schema import DBType
|
from pilot.common.schema import DBType
|
||||||
from pilot.component import SystemApp
|
from pilot.component import SystemApp, ComponentType
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory
|
||||||
|
|
||||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||||
from pilot.connections.base import BaseConnect
|
from pilot.connections.base import BaseConnect
|
||||||
|
|
||||||
@ -76,7 +78,11 @@ class ConnectManager:
|
|||||||
+ CFG.LOCAL_DB_HOST
|
+ CFG.LOCAL_DB_HOST
|
||||||
+ ":"
|
+ ":"
|
||||||
+ str(CFG.LOCAL_DB_PORT),
|
+ str(CFG.LOCAL_DB_PORT),
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
engine_args={
|
||||||
|
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
|
||||||
|
"pool_recycle": 3600,
|
||||||
|
"echo": True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
# default_mysql = MySQLConnect.from_uri(
|
# default_mysql = MySQLConnect.from_uri(
|
||||||
# "mysql+pymysql://"
|
# "mysql+pymysql://"
|
||||||
@ -208,13 +214,15 @@ class ConnectManager:
|
|||||||
db_info.comment,
|
db_info.comment,
|
||||||
)
|
)
|
||||||
# async embedding
|
# async embedding
|
||||||
thread = threading.Thread(
|
executor = CFG.SYSTEM_APP.get_component(
|
||||||
target=self.db_summary_client.db_summary_embedding(
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
db_info.db_name, db_info.db_type
|
).create()
|
||||||
)
|
executor.submit(
|
||||||
|
self.db_summary_client.db_summary_embedding,
|
||||||
|
db_info.db_name,
|
||||||
|
db_info.db_type,
|
||||||
)
|
)
|
||||||
thread.start()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Add db connect info error!" + str(e))
|
raise ValueError("Add db connect info error!" + str(e))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
from typing import TYPE_CHECKING, Any, Type
|
from typing import TYPE_CHECKING, Any, Type
|
||||||
|
|
||||||
from pilot.component import ComponentType, SystemApp
|
from pilot.component import ComponentType, SystemApp
|
||||||
|
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
from pilot.server.base import WebWerverParameters
|
from pilot.server.base import WebWerverParameters
|
||||||
|
|
||||||
@ -22,6 +23,9 @@ def initialize_components(
|
|||||||
):
|
):
|
||||||
from pilot.model.cluster.controller.controller import controller
|
from pilot.model.cluster.controller.controller import controller
|
||||||
|
|
||||||
|
# Register global default executor factory first
|
||||||
|
system_app.register(DefaultExecutorFactory)
|
||||||
|
|
||||||
system_app.register_instance(controller)
|
system_app.register_instance(controller)
|
||||||
|
|
||||||
_initialize_embedding_model(
|
_initialize_embedding_model(
|
||||||
|
@ -124,12 +124,13 @@ def knowledge_init(
|
|||||||
def upload(filename: str):
|
def upload(filename: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Begin upload document: {filename} to {space.name}")
|
logger.info(f"Begin upload document: {filename} to {space.name}")
|
||||||
return client.document_upload(
|
doc_id = client.document_upload(
|
||||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||||
)
|
)
|
||||||
|
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if skip_wrong_doc:
|
if skip_wrong_doc:
|
||||||
logger.warn(f"Warning: {str(e)}")
|
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -144,5 +145,3 @@ def knowledge_init(
|
|||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
logger.warn("Warning: no document to sync")
|
logger.warn("Warning: no document to sync")
|
||||||
return
|
return
|
||||||
logger.info(f"Begin sync document: {doc_ids}")
|
|
||||||
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
|
|
||||||
|
@ -9,6 +9,9 @@ from pilot.configs.model_config import (
|
|||||||
EMBEDDING_MODEL_CONFIG,
|
EMBEDDING_MODEL_CONFIG,
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
)
|
)
|
||||||
|
from pilot.component import ComponentType
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory
|
||||||
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.server.knowledge.chunk_db import (
|
from pilot.server.knowledge.chunk_db import (
|
||||||
DocumentChunkEntity,
|
DocumentChunkEntity,
|
||||||
@ -227,10 +230,15 @@ class KnowledgeService:
|
|||||||
doc.gmt_modified = datetime.now()
|
doc.gmt_modified = datetime.now()
|
||||||
knowledge_document_dao.update_knowledge_document(doc)
|
knowledge_document_dao.update_knowledge_document(doc)
|
||||||
# async doc embeddings
|
# async doc embeddings
|
||||||
thread = threading.Thread(
|
# thread = threading.Thread(
|
||||||
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||||
)
|
# )
|
||||||
thread.start()
|
# thread.start()
|
||||||
|
executor = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
|
).create()
|
||||||
|
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||||
|
|
||||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||||
# save chunk details
|
# save chunk details
|
||||||
chunk_entities = [
|
chunk_entities = [
|
||||||
|
26
pilot/utils/executor_utils.py
Normal file
26
pilot/utils/executor_utils.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
|
|
||||||
|
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component that hands out a ``concurrent.futures.Executor``.

    The component is registered under ``ComponentType.EXECUTOR_DEFAULT`` so
    that callers can look it up on the system app and call :meth:`create`
    to obtain an executor to submit work to.
    """

    # Component name used for registration and lookup.
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Return the executor that callers should submit their tasks to."""
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultExecutorFactory(ExecutorFactory):
    """Executor factory backed by a single ``ThreadPoolExecutor``.

    The pool is built once in ``__init__``; every :meth:`create` call returns
    that same executor instance.
    """

    # NOTE: the annotations below are quoted on purpose. An unquoted
    # ``SystemApp | None`` is evaluated when the ``def`` statement runs and
    # raises ``TypeError`` on interpreters older than Python 3.10.
    def __init__(
        self,
        system_app: "SystemApp | None" = None,
        max_workers: "int | None" = None,
    ):
        """Create the backing thread pool.

        Args:
            system_app: Forwarded unchanged to the base component constructor.
            max_workers: Forwarded to ``ThreadPoolExecutor``; ``None`` leaves
                the pool size to the standard-library default.
        """
        super().__init__(system_app)
        # Worker threads are named after the component for easier debugging.
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp):
        """No-op: the executor is already created in ``__init__``."""
        pass

    def create(self) -> Executor:
        """Return the shared thread pool executor built in ``__init__``."""
        return self._executor
|
6
setup.py
6
setup.py
@ -8,7 +8,6 @@ from enum import Enum
|
|||||||
import urllib.request
|
import urllib.request
|
||||||
from urllib.parse import urlparse, quote
|
from urllib.parse import urlparse, quote
|
||||||
import re
|
import re
|
||||||
from pip._internal.utils.appdirs import user_cache_dir
|
|
||||||
import shutil
|
import shutil
|
||||||
from setuptools import find_packages
|
from setuptools import find_packages
|
||||||
|
|
||||||
@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
|
|||||||
safe_url, parsed_url = encode_url(package_url)
|
safe_url, parsed_url = encode_url(package_url)
|
||||||
if BUILD_NO_CACHE:
|
if BUILD_NO_CACHE:
|
||||||
return safe_url
|
return safe_url
|
||||||
|
|
||||||
|
from pip._internal.utils.appdirs import user_cache_dir
|
||||||
|
|
||||||
filename = os.path.basename(parsed_url)
|
filename = os.path.basename(parsed_url)
|
||||||
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
|
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
@ -279,7 +281,7 @@ def core_requires():
|
|||||||
"importlib-resources==5.12.0",
|
"importlib-resources==5.12.0",
|
||||||
"psutil==5.9.4",
|
"psutil==5.9.4",
|
||||||
"python-dotenv==1.0.0",
|
"python-dotenv==1.0.0",
|
||||||
"colorama",
|
"colorama==0.4.10",
|
||||||
"prettytable",
|
"prettytable",
|
||||||
"cachetools",
|
"cachetools",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user