diff --git a/pilot/componet.py b/pilot/component.py similarity index 80% rename from pilot/componet.py rename to pilot/component.py index 73e32490b..7c865479d 100644 --- a/pilot/componet.py +++ b/pilot/component.py @@ -42,16 +42,16 @@ class LifeCycle: pass -class ComponetType(str, Enum): +class ComponentType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" MODEL_CONTROLLER = "dbgpt_model_controller" -class BaseComponet(LifeCycle, ABC): +class BaseComponent(LifeCycle, ABC): """Abstract Base Component class. All custom components should extend this.""" - name = "base_dbgpt_componet" + name = "base_dbgpt_component" def __init__(self, system_app: Optional[SystemApp] = None): if system_app is not None: @@ -67,15 +67,15 @@ class BaseComponet(LifeCycle, ABC): pass -T = TypeVar("T", bound=BaseComponet) +T = TypeVar("T", bound=BaseComponent) class SystemApp(LifeCycle): """Main System Application class that manages the lifecycle and registration of components.""" def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: - self.componets: Dict[ - str, BaseComponet + self.components: Dict[ + str, BaseComponent ] = {} # Dictionary to store registered components. self._asgi_app = asgi_app @@ -84,58 +84,60 @@ class SystemApp(LifeCycle): """Returns the internal ASGI app.""" return self._asgi_app - def register(self, componet: Type[BaseComponet], *args, **kwargs): + def register(self, component: Type[BaseComponent], *args, **kwargs): """Register a new component by its type.""" - instance = componet(self, *args, **kwargs) + instance = component(self, *args, **kwargs) self.register_instance(instance) def register_instance(self, instance: T): """Register an already initialized component.""" name = instance.name - if isinstance(name, ComponetType): + if isinstance(name, ComponentType): name = name.value - if name in self.componets: + if name in self.components: raise RuntimeError( - f"Componse name {name} already exists: {self.componets[name]}" + f"Componse name {name} already exists: {self.components[name]}" ) - logger.info(f"Register componet with name {name} and instance: {instance}") - self.componets[name] = instance + logger.info(f"Register component with name {name} and instance: {instance}") + self.components[name] = instance instance.init_app(self) - def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T: + def get_component( + self, name: Union[str, ComponentType], component_type: Type[T] + ) -> T: """Retrieve a registered component by its name and type.""" - if isinstance(name, ComponetType): + if isinstance(name, ComponentType): name = name.value - component = self.componets.get(name) + component = self.components.get(name) if not component: raise ValueError(f"No component found with name {name}") - if not isinstance(component, componet_type): - raise TypeError(f"Component {name} is not of type {componet_type}") + if not isinstance(component, component_type): + raise TypeError(f"Component {name} is not of type {component_type}") return component def before_start(self): """Invoke the before_start hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): v.before_start() async def async_before_start(self): """Asynchronously invoke the before_start hooks for all registered components.""" - tasks = [v.async_before_start() for _, v in self.componets.items()] + tasks = [v.async_before_start() for _, v in self.components.items()] await asyncio.gather(*tasks) def after_start(self): """Invoke the after_start hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): v.after_start() async def async_after_start(self): """Asynchronously invoke the after_start hooks for all registered components.""" - tasks = [v.async_after_start() for _, v in self.componets.items()] + tasks = [v.async_after_start() for _, v in self.components.items()] await asyncio.gather(*tasks) def before_stop(self): """Invoke the before_stop hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): try: v.before_stop() except Exception as e: @@ -143,7 +145,7 @@ class SystemApp(LifeCycle): async def async_before_stop(self): """Asynchronously invoke the before_stop hooks for all registered components.""" - tasks = [v.async_before_stop() for _, v in self.componets.items()] + tasks = [v.async_before_stop() for _, v in self.components.items()] await asyncio.gather(*tasks) def _build(self): diff --git a/pilot/configs/config.py b/pilot/configs/config.py index e068613c1..5d256a8e5 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -189,7 +189,7 @@ class Config(metaclass=Singleton): ### Log level self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO") - from pilot.componet import SystemApp + from pilot.component import SystemApp self.SYSTEM_APP: SystemApp = None diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index b5cfbbfdd..534cd36f0 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -4,7 +4,7 @@ import asyncio from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.common.schema import DBType -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.base import BaseConnect diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py index c6f51f4c9..5b49df767 100644 --- a/pilot/embedding_engine/embedding_factory.py +++ b/pilot/embedding_engine/embedding_factory.py @@ -2,13 +2,13 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, Type, TYPE_CHECKING -from pilot.componet import BaseComponet +from pilot.component import BaseComponent if TYPE_CHECKING: from langchain.embeddings.base import Embeddings -class EmbeddingFactory(BaseComponet, ABC): +class EmbeddingFactory(BaseComponent, ABC): name = "embedding_factory" @abstractmethod diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 54360e477..e93216929 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -4,7 +4,7 @@ import logging from typing import List from fastapi import APIRouter, FastAPI -from pilot.componet import BaseComponet, ComponetType, SystemApp +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import ModelInstance from pilot.model.parameter import ModelControllerParameters from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry @@ -15,8 +15,8 @@ from pilot.utils.api_utils import ( ) -class BaseModelController(BaseComponet, ABC): - name = ComponetType.MODEL_CONTROLLER +class BaseModelController(BaseComponent, ABC): + name = ComponentType.MODEL_CONTROLLER def init_app(self, system_app: SystemApp): pass diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py index e00455439..ce37755f6 100644 --- a/pilot/model/cluster/manager_base.py +++ b/pilot/model/cluster/manager_base.py @@ -4,7 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable from abc import ABC, abstractmethod from datetime import datetime from concurrent.futures import Future -from pilot.componet import BaseComponet, ComponetType, SystemApp +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest @@ -107,8 +107,8 @@ class WorkerManager(ABC): """Get parameter descriptions of model""" -class WorkerManagerFactory(BaseComponet, ABC): - name = ComponetType.WORKER_MANAGER_FACTORY.value +class WorkerManagerFactory(BaseComponent, ABC): + name = ComponentType.WORKER_MANAGER_FACTORY.value def init_app(self, system_app: SystemApp): pass diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index ef75c5ccf..2576c6a21 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -11,7 +11,7 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional from fastapi import APIRouter, FastAPI from fastapi.responses import StreamingResponse -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.configs.model_config import LOGDIR from pilot.model.base import ( ModelInstance, diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 8af1528cc..111487f00 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -20,7 +20,7 @@ from fastapi.exceptions import RequestValidationError from typing import List import tempfile -from pilot.componet import ComponetType +from pilot.component import ComponentType from pilot.openapi.api_view_model import ( Result, ConversationVo, @@ -96,15 +96,15 @@ def knowledge_list(): def get_model_controller() -> BaseModelController: - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController + controller = CFG.SYSTEM_APP.get_component( + ComponentType.MODEL_CONTROLLER, BaseModelController ) return controller def get_worker_manager() -> WorkerManager: - worker_manager = CFG.SYSTEM_APP.get_componet( - ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() return worker_manager diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index a11d5086f..70805ca80 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -6,7 +6,7 @@ from typing import Any, List, Dict from pilot.configs.config import Config from pilot.configs.model_config import LOGDIR -from pilot.componet import ComponetType +from pilot.component import ComponentType from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory @@ -145,8 +145,8 @@ class BaseChat(ABC): try: from pilot.model.cluster import WorkerManagerFactory - worker_manager = CFG.SYSTEM_APP.get_componet( - ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() async for output in worker_manager.generate_stream(payload): yield output @@ -166,8 +166,8 @@ class BaseChat(ABC): try: from pilot.model.cluster import WorkerManagerFactory - worker_manager = CFG.SYSTEM_APP.get_componet( - ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() model_output = await worker_manager.generate(payload) diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 3856cfb05..3b9bacb99 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat): "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) self.knowledge_embedding_client = EmbeddingEngine( diff --git a/pilot/server/base.py b/pilot/server/base.py index 34b48f599..5ebbc0003 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -6,7 +6,7 @@ from typing import Optional, Any from dataclasses import dataclass, field from pilot.configs.config import Config -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.utils.parameter_utils import BaseParameters diff --git a/pilot/server/componet_configs.py b/pilot/server/component_configs.py similarity index 94% rename from pilot/server/componet_configs.py rename to pilot/server/component_configs.py index 41b5d2ddb..d2dded0a1 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/component_configs.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Type -from pilot.componet import ComponetType, SystemApp +from pilot.component import ComponentType, SystemApp from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.server.base import WebWerverParameters @@ -14,7 +14,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def initialize_componets( +def initialize_components( param: WebWerverParameters, system_app: SystemApp, embedding_model_name: str, @@ -65,8 +65,8 @@ class RemoteEmbeddingFactory(EmbeddingFactory): if embedding_cls: raise NotImplementedError - worker_manager = self.system_app.get_componet( - ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() # Ignore model_name args return RemoteEmbeddings(self._default_model_name, worker_manager) diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index d82086169..3df72ff32 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH) import signal from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.server.base import ( server_init, WebWerverParameters, _create_model_start_listener, ) -from pilot.server.componet_configs import initialize_componets +from pilot.server.component_configs import initialize_components from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications @@ -118,7 +118,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): server_init(param, system_app) model_start_listener = _create_model_start_listener(system_app) - initialize_componets(param, system_app, embedding_model_name, embedding_model_path) + initialize_components(param, system_app, embedding_model_name, embedding_model_path) model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] if not param.light: diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index f43333aa1..71b939924 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest): @router.post("/knowledge/{vector_name}/query") def similar_query(space_name: str, query_request: KnowledgeQueryRequest): print(f"Received params: {space_name}, {query_request}") - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) client = EmbeddingEngine( diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 0c04dee3a..fc07040c7 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -205,7 +205,7 @@ class KnowledgeService: chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) client = EmbeddingEngine( diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index f41043601..d4850ec08 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -2,7 +2,7 @@ import json import uuid from pilot.common.schema import DBType -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.configs.config import Config from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, @@ -36,7 +36,7 @@ class DBSummaryClient: from pilot.embedding_engine.embedding_factory import EmbeddingFactory db_summary_client = RdbmsSummary(dbname, db_type) - embedding_factory = self.system_app.get_componet( + embedding_factory = self.system_app.get_component( "embedding_factory", EmbeddingFactory ) embeddings = embedding_factory.create( @@ -94,7 +94,7 @@ class DBSummaryClient: "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) knowledge_embedding_client = EmbeddingEngine( @@ -117,7 +117,7 @@ class DBSummaryClient: "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) knowledge_embedding_client = EmbeddingEngine(