chore: rename component

This commit is contained in:
FangYin Cheng 2023-09-19 10:23:39 +08:00
parent ae34be23fd
commit 2d4d513eb5
16 changed files with 62 additions and 60 deletions

View File

@ -42,16 +42,16 @@ class LifeCycle:
pass pass
class ComponetType(str, Enum): class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller" MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponet(LifeCycle, ABC): class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this.""" """Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_componet" name = "base_dbgpt_component"
def __init__(self, system_app: Optional[SystemApp] = None): def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None: if system_app is not None:
@ -67,15 +67,15 @@ class BaseComponet(LifeCycle, ABC):
pass pass
T = TypeVar("T", bound=BaseComponet) T = TypeVar("T", bound=BaseComponent)
class SystemApp(LifeCycle): class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components.""" """Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.componets: Dict[ self.components: Dict[
str, BaseComponet str, BaseComponent
] = {} # Dictionary to store registered components. ] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app self._asgi_app = asgi_app
@ -84,58 +84,60 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app.""" """Returns the internal ASGI app."""
return self._asgi_app return self._asgi_app
def register(self, componet: Type[BaseComponet], *args, **kwargs): def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type.""" """Register a new component by its type."""
instance = componet(self, *args, **kwargs) instance = component(self, *args, **kwargs)
self.register_instance(instance) self.register_instance(instance)
def register_instance(self, instance: T): def register_instance(self, instance: T):
"""Register an already initialized component.""" """Register an already initialized component."""
name = instance.name name = instance.name
if isinstance(name, ComponetType): if isinstance(name, ComponentType):
name = name.value name = name.value
if name in self.componets: if name in self.components:
raise RuntimeError( raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}" f"Componse name {name} already exists: {self.components[name]}"
) )
logger.info(f"Register componet with name {name} and instance: {instance}") logger.info(f"Register component with name {name} and instance: {instance}")
self.componets[name] = instance self.components[name] = instance
instance.init_app(self) instance.init_app(self)
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T: def get_component(
self, name: Union[str, ComponentType], component_type: Type[T]
) -> T:
"""Retrieve a registered component by its name and type.""" """Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType): if isinstance(name, ComponentType):
name = name.value name = name.value
component = self.componets.get(name) component = self.components.get(name)
if not component: if not component:
raise ValueError(f"No component found with name {name}") raise ValueError(f"No component found with name {name}")
if not isinstance(component, componet_type): if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {componet_type}") raise TypeError(f"Component {name} is not of type {component_type}")
return component return component
def before_start(self): def before_start(self):
"""Invoke the before_start hooks for all registered components.""" """Invoke the before_start hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
v.before_start() v.before_start()
async def async_before_start(self): async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components.""" """Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.componets.items()] tasks = [v.async_before_start() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def after_start(self): def after_start(self):
"""Invoke the after_start hooks for all registered components.""" """Invoke the after_start hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
v.after_start() v.after_start()
async def async_after_start(self): async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components.""" """Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.componets.items()] tasks = [v.async_after_start() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def before_stop(self): def before_stop(self):
"""Invoke the before_stop hooks for all registered components.""" """Invoke the before_stop hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
try: try:
v.before_stop() v.before_stop()
except Exception as e: except Exception as e:
@ -143,7 +145,7 @@ class SystemApp(LifeCycle):
async def async_before_stop(self): async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components.""" """Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.componets.items()] tasks = [v.async_before_stop() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def _build(self): def _build(self):

View File

@ -189,7 +189,7 @@ class Config(metaclass=Singleton):
### Log level ### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO") self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.componet import SystemApp from pilot.component import SystemApp
self.SYSTEM_APP: SystemApp = None self.SYSTEM_APP: SystemApp = None

View File

@ -4,7 +4,7 @@ import asyncio
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect from pilot.connections.base import BaseConnect

View File

@ -2,13 +2,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING from typing import Any, Type, TYPE_CHECKING
from pilot.componet import BaseComponet from pilot.component import BaseComponent
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
class EmbeddingFactory(BaseComponet, ABC): class EmbeddingFactory(BaseComponent, ABC):
name = "embedding_factory" name = "embedding_factory"
@abstractmethod @abstractmethod

View File

@ -4,7 +4,7 @@ import logging
from typing import List from typing import List
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@ -15,8 +15,8 @@ from pilot.utils.api_utils import (
) )
class BaseModelController(BaseComponet, ABC): class BaseModelController(BaseComponent, ABC):
name = ComponetType.MODEL_CONTROLLER name = ComponentType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp): def init_app(self, system_app: SystemApp):
pass pass

View File

@ -4,7 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from concurrent.futures import Future from concurrent.futures import Future
from pilot.componet import BaseComponet, ComponetType, SystemApp from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@ -107,8 +107,8 @@ class WorkerManager(ABC):
"""Get parameter descriptions of model""" """Get parameter descriptions of model"""
class WorkerManagerFactory(BaseComponet, ABC): class WorkerManagerFactory(BaseComponent, ABC):
name = ComponetType.WORKER_MANAGER_FACTORY.value name = ComponentType.WORKER_MANAGER_FACTORY.value
def init_app(self, system_app: SystemApp): def init_app(self, system_app: SystemApp):
pass pass

View File

@ -11,7 +11,7 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.model.base import ( from pilot.model.base import (
ModelInstance, ModelInstance,

View File

@ -20,7 +20,7 @@ from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
import tempfile import tempfile
from pilot.componet import ComponetType from pilot.component import ComponentType
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
Result, Result,
ConversationVo, ConversationVo,
@ -96,15 +96,15 @@ def knowledge_list():
def get_model_controller() -> BaseModelController: def get_model_controller() -> BaseModelController:
controller = CFG.SYSTEM_APP.get_componet( controller = CFG.SYSTEM_APP.get_component(
ComponetType.MODEL_CONTROLLER, BaseModelController ComponentType.MODEL_CONTROLLER, BaseModelController
) )
return controller return controller
def get_worker_manager() -> WorkerManager: def get_worker_manager() -> WorkerManager:
worker_manager = CFG.SYSTEM_APP.get_componet( worker_manager = CFG.SYSTEM_APP.get_component(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
return worker_manager return worker_manager

View File

@ -6,7 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.componet import ComponetType from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -145,8 +145,8 @@ class BaseChat(ABC):
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet( worker_manager = CFG.SYSTEM_APP.get_component(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
async for output in worker_manager.generate_stream(payload): async for output in worker_manager.generate_stream(payload):
yield output yield output
@ -166,8 +166,8 @@ class BaseChat(ABC):
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet( worker_manager = CFG.SYSTEM_APP.get_component(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
model_output = await worker_manager.generate(payload) model_output = await worker_manager.generate(payload)

View File

@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat):
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(

View File

@ -6,7 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.utils.parameter_utils import BaseParameters from pilot.utils.parameter_utils import BaseParameters

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Type from typing import TYPE_CHECKING, Any, Type
from pilot.componet import ComponetType, SystemApp from pilot.component import ComponentType, SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters from pilot.server.base import WebWerverParameters
@ -14,7 +14,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def initialize_componets( def initialize_components(
param: WebWerverParameters, param: WebWerverParameters,
system_app: SystemApp, system_app: SystemApp,
embedding_model_name: str, embedding_model_name: str,
@ -65,8 +65,8 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
if embedding_cls: if embedding_cls:
raise NotImplementedError raise NotImplementedError
worker_manager = self.system_app.get_componet( worker_manager = self.system_app.get_component(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
# Ignore model_name args # Ignore model_name args
return RemoteEmbeddings(self._default_model_name, worker_manager) return RemoteEmbeddings(self._default_model_name, worker_manager)

View File

@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH)
import signal import signal
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.server.base import ( from pilot.server.base import (
server_init, server_init,
WebWerverParameters, WebWerverParameters,
_create_model_start_listener, _create_model_start_listener,
) )
from pilot.server.componet_configs import initialize_componets from pilot.server.component_configs import initialize_components
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications from fastapi import FastAPI, applications
@ -118,7 +118,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
server_init(param, system_app) server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app) model_start_listener = _create_model_start_listener(system_app)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path) initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light: if not param.light:

View File

@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query") @router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest): def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}") print(f"Received params: {space_name}, {query_request}")
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
client = EmbeddingEngine( client = EmbeddingEngine(

View File

@ -205,7 +205,7 @@ class KnowledgeService:
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
client = EmbeddingEngine( client = EmbeddingEngine(

View File

@ -2,7 +2,7 @@ import json
import uuid import uuid
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -36,7 +36,7 @@ class DBSummaryClient:
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type) db_summary_client = RdbmsSummary(dbname, db_type)
embedding_factory = self.system_app.get_componet( embedding_factory = self.system_app.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
embeddings = embedding_factory.create( embeddings = embedding_factory.create(
@ -94,7 +94,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(
@ -117,7 +117,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(