Files
DB-GPT/pilot/component.py
2023-09-19 10:23:39 +08:00

175 lines
5.6 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
import asyncio
# FastAPI is imported only for static type checkers: TYPE_CHECKING is False at
# runtime, so this module does not import FastAPI when it is executed.
if TYPE_CHECKING:
    from fastapi import FastAPI

# Module-level logger used for component registration and lifecycle errors.
logger = logging.getLogger(__name__)
class LifeCycle:
    """Lifecycle hook interface shared by components and the system app.

    Every hook is a no-op by default, so subclasses only override the hooks
    they care about. Each synchronous hook has an ``async_`` coroutine
    counterpart.
    """

    def before_start(self):
        """Hook invoked before startup; does nothing by default."""

    async def async_before_start(self):
        """Coroutine counterpart of ``before_start``; does nothing by default."""

    def after_start(self):
        """Hook invoked once startup has completed; does nothing by default."""

    async def async_after_start(self):
        """Coroutine counterpart of ``after_start``; does nothing by default."""

    def before_stop(self):
        """Hook invoked before shutdown; does nothing by default."""

    async def async_before_stop(self):
        """Coroutine counterpart of ``before_stop``; does nothing by default."""
class ComponentType(str, Enum):
    """Predefined component names usable as registry keys in ``SystemApp``.

    The ``str`` mixin makes each member compare equal to its string value.
    ``SystemApp`` converts members to ``.value`` before storing or looking up
    components.
    """

    WORKER_MANAGER = "dbgpt_worker_manager"
    WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
    MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponent(LifeCycle, ABC):
    """Abstract base for every component managed by a ``SystemApp``.

    Concrete components must implement ``init_app`` and should override
    ``name`` with a unique identifier. ``SystemApp`` uses that name as the
    registry key and rejects duplicates.
    """

    name = "base_dbgpt_component"

    def __init__(self, system_app: Optional[SystemApp] = None):
        # Constructed standalone: nothing to bind yet; init_app can be invoked
        # later (SystemApp.register_instance does so).
        if system_app is None:
            return
        # NOTE: SystemApp.register passes the app here AND register_instance
        # calls init_app again, so init_app may run twice for one component.
        self.init_app(system_app)

    @abstractmethod
    def init_app(self, system_app: SystemApp):
        """Wire this component into ``system_app``.

        Every concrete component must implement this hook to describe how it
        integrates with the main system application.
        """
# Type variable bound to BaseComponent so register_instance/get_component can be
# annotated in terms of the concrete component subclass the caller works with.
T = TypeVar("T", bound=BaseComponent)
class SystemApp(LifeCycle):
    """Main system application: component registry and lifecycle driver.

    Components are stored by name in ``self.components``. The lifecycle hooks
    inherited from ``LifeCycle`` fan out to every registered component. When
    an ASGI app is supplied, ``_build`` wires those hooks into its startup and
    shutdown events.
    """

    def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
        # Registered components, keyed by their string name.
        self.components: Dict[str, BaseComponent] = {}
        self._asgi_app = asgi_app
        # Strong reference to the background startup task created by _build.
        # The event loop keeps only weak references to tasks, so without this
        # the task could be garbage-collected before it finishes.
        self._startup_task: Optional[asyncio.Task] = None

    @property
    def app(self) -> Optional["FastAPI"]:
        """Return the wrapped ASGI app, or ``None`` if none was supplied."""
        return self._asgi_app

    @staticmethod
    def _normalize_name(name: Union[str, ComponentType]) -> str:
        """Map a ComponentType member to its string value; pass strings through."""
        return name.value if isinstance(name, ComponentType) else name

    def register(self, component: Type[T], *args, **kwargs) -> T:
        """Instantiate ``component`` with this app and register the instance.

        Extra positional and keyword arguments are forwarded to the
        component's constructor after ``self``.

        Returns:
            The newly created and registered instance.

        Raises:
            RuntimeError: If a component with the same name is already
                registered.
        """
        instance = component(self, *args, **kwargs)
        self.register_instance(instance)
        return instance

    def register_instance(self, instance: T) -> T:
        """Register an already constructed component and initialize it.

        ``instance.init_app(self)`` is always called here. ``register`` also
        passes ``self`` to the constructor, where ``BaseComponent.__init__``
        calls ``init_app`` too. Implementations should therefore tolerate
        being initialized more than once.

        Returns:
            ``instance``, for convenience.

        Raises:
            RuntimeError: If a component with the same name is already
                registered.
        """
        name = self._normalize_name(instance.name)
        if name in self.components:
            raise RuntimeError(
                f"Component name {name} already exists: {self.components[name]}"
            )
        logger.info(f"Register component with name {name} and instance: {instance}")
        self.components[name] = instance
        instance.init_app(self)
        return instance

    def get_component(
        self, name: Union[str, ComponentType], component_type: Type[T]
    ) -> T:
        """Return the component registered under ``name``.

        Raises:
            ValueError: If no component is registered under ``name``.
            TypeError: If the component is not an instance of
                ``component_type``.
        """
        name = self._normalize_name(name)
        component = self.components.get(name)
        # Compare against None explicitly: a registered component may be falsy
        # (e.g. if it defines __len__ or __bool__).
        if component is None:
            raise ValueError(f"No component found with name {name}")
        if not isinstance(component, component_type):
            raise TypeError(f"Component {name} is not of type {component_type}")
        return component

    def before_start(self):
        """Invoke ``before_start`` on every registered component, in order."""
        for component in self.components.values():
            component.before_start()

    async def async_before_start(self):
        """Run every component's ``async_before_start`` concurrently."""
        await asyncio.gather(
            *(c.async_before_start() for c in self.components.values())
        )

    def after_start(self):
        """Invoke ``after_start`` on every registered component, in order."""
        for component in self.components.values():
            component.after_start()

    async def async_after_start(self):
        """Run every component's ``async_after_start`` concurrently."""
        await asyncio.gather(
            *(c.async_after_start() for c in self.components.values())
        )

    def before_stop(self):
        """Invoke ``before_stop`` on every component, best effort.

        A failing hook is logged, with its traceback, instead of aborting
        shutdown, so the remaining components still get to stop.
        """
        for name, component in self.components.items():
            try:
                component.before_stop()
            except Exception:
                logger.exception(f"Error in before_stop of component {name}")

    async def async_before_stop(self):
        """Run every component's ``async_before_stop`` concurrently, best effort.

        As with ``before_stop``, a failing hook is logged rather than
        propagated, and the other hooks still run to completion.
        """
        items = list(self.components.items())
        results = await asyncio.gather(
            *(component.async_before_stop() for _, component in items),
            return_exceptions=True,
        )
        for (name, _), result in zip(items, results):
            if isinstance(result, BaseException):
                logger.error(
                    f"Error in async_before_stop of component {name}",
                    exc_info=result,
                )

    def _build(self):
        """Hook the lifecycle methods into the ASGI app's startup/shutdown events.

        Does nothing when no ASGI app was supplied.
        """
        if self.app is None:
            return

        @self.app.on_event("startup")
        async def startup_event():
            """ASGI startup handler: run async_after_start in the background."""

            async def _startup_func():
                try:
                    await self.async_after_start()
                except Exception as e:
                    logger.exception(f"Error starting system app: {e}")
                    # SystemExit raised inside a task propagates out of the
                    # event loop, terminating the server process.
                    sys.exit(1)

            # Keep a reference so the task cannot be garbage-collected while
            # it is still running.
            self._startup_task = asyncio.create_task(_startup_func())
            self.after_start()

        @self.app.on_event("shutdown")
        async def shutdown_event():
            """ASGI shutdown handler: run the async, then the sync stop hooks."""
            try:
                await self.async_before_stop()
            finally:
                # Synchronous stop hooks must still run even if the async
                # phase was interrupted.
                self.before_stop()