chore: rename component

This commit is contained in:
FangYin Cheng 2023-09-19 10:23:39 +08:00
parent ae34be23fd
commit 2d4d513eb5
16 changed files with 62 additions and 60 deletions

View File

@ -42,16 +42,16 @@ class LifeCycle:
pass
class ComponetType(str, Enum):
class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponet(LifeCycle, ABC):
class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_componet"
name = "base_dbgpt_component"
def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None:
@ -67,15 +67,15 @@ class BaseComponet(LifeCycle, ABC):
pass
T = TypeVar("T", bound=BaseComponet)
T = TypeVar("T", bound=BaseComponent)
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.componets: Dict[
str, BaseComponet
self.components: Dict[
str, BaseComponent
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
@ -84,58 +84,60 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, componet: Type[BaseComponet], *args, **kwargs):
def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type."""
instance = componet(self, *args, **kwargs)
instance = component(self, *args, **kwargs)
self.register_instance(instance)
def register_instance(self, instance: T):
"""Register an already initialized component."""
name = instance.name
if isinstance(name, ComponetType):
if isinstance(name, ComponentType):
name = name.value
if name in self.componets:
if name in self.components:
raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}"
f"Componse name {name} already exists: {self.components[name]}"
)
logger.info(f"Register componet with name {name} and instance: {instance}")
self.componets[name] = instance
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
def get_component(
self, name: Union[str, ComponentType], component_type: Type[T]
) -> T:
"""Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType):
if isinstance(name, ComponentType):
name = name.value
component = self.componets.get(name)
component = self.components.get(name)
if not component:
raise ValueError(f"No component found with name {name}")
if not isinstance(component, componet_type):
raise TypeError(f"Component {name} is not of type {componet_type}")
if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {component_type}")
return component
def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
v.before_start()
async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.componets.items()]
tasks = [v.async_before_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def after_start(self):
"""Invoke the after_start hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
v.after_start()
async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.componets.items()]
tasks = [v.async_after_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def before_stop(self):
"""Invoke the before_stop hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
try:
v.before_stop()
except Exception as e:
@ -143,7 +145,7 @@ class SystemApp(LifeCycle):
async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.componets.items()]
tasks = [v.async_before_stop() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def _build(self):

View File

@ -189,7 +189,7 @@ class Config(metaclass=Singleton):
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.componet import SystemApp
from pilot.component import SystemApp
self.SYSTEM_APP: SystemApp = None

View File

@ -4,7 +4,7 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect

View File

@ -2,13 +2,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import BaseComponet
from pilot.component import BaseComponent
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
class EmbeddingFactory(BaseComponet, ABC):
class EmbeddingFactory(BaseComponent, ABC):
name = "embedding_factory"
@abstractmethod

View File

@ -4,7 +4,7 @@ import logging
from typing import List
from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@ -15,8 +15,8 @@ from pilot.utils.api_utils import (
)
class BaseModelController(BaseComponet, ABC):
name = ComponetType.MODEL_CONTROLLER
class BaseModelController(BaseComponent, ABC):
name = ComponentType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp):
pass

View File

@ -4,7 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@ -107,8 +107,8 @@ class WorkerManager(ABC):
"""Get parameter descriptions of model"""
class WorkerManagerFactory(BaseComponet, ABC):
name = ComponetType.WORKER_MANAGER_FACTORY.value
class WorkerManagerFactory(BaseComponent, ABC):
name = ComponentType.WORKER_MANAGER_FACTORY.value
def init_app(self, system_app: SystemApp):
pass

View File

@ -11,7 +11,7 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,

View File

@ -20,7 +20,7 @@ from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile
from pilot.componet import ComponetType
from pilot.component import ComponentType
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
@ -96,15 +96,15 @@ def knowledge_list():
def get_model_controller() -> BaseModelController:
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
return controller
def get_worker_manager() -> WorkerManager:
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return worker_manager

View File

@ -6,7 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.componet import ComponetType
from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -145,8 +145,8 @@ class BaseChat(ABC):
try:
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload):
yield output
@ -166,8 +166,8 @@ class BaseChat(ABC):
try:
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload)

View File

@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat):
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
self.knowledge_embedding_client = EmbeddingEngine(

View File

@ -6,7 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.utils.parameter_utils import BaseParameters

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
from pilot.componet import ComponetType, SystemApp
from pilot.component import ComponentType, SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@ -14,7 +14,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def initialize_componets(
def initialize_components(
param: WebWerverParameters,
system_app: SystemApp,
embedding_model_name: str,
@ -65,8 +65,8 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, worker_manager)

View File

@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.server.base import (
server_init,
WebWerverParameters,
_create_model_start_listener,
)
from pilot.server.componet_configs import initialize_componets
from pilot.server.component_configs import initialize_components
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@ -118,7 +118,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:

View File

@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}")
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(

View File

@ -205,7 +205,7 @@ class KnowledgeService:
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(

View File

@ -2,7 +2,7 @@ import json
import uuid
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -36,7 +36,7 @@ class DBSummaryClient:
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type)
embedding_factory = self.system_app.get_componet(
embedding_factory = self.system_app.get_component(
"embedding_factory", EmbeddingFactory
)
embeddings = embedding_factory.create(
@ -94,7 +94,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
@ -117,7 +117,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(