feat(awel): AWEL supports http trigger

FangYin Cheng 2023-11-20 22:42:04 +08:00
parent 35c1221cdc
commit e67d62a785
20 changed files with 655 additions and 12 deletions

View File

@ -0,0 +1,54 @@
"""AWEL: Simple chat dag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"user_input": "hello"
}'
"""
from typing import Dict
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
from pilot.scene.base_message import ModelMessage
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator
class TriggerReqBody(BaseModel):
model: str = Field(..., description="Model name")
user_input: str = Field(..., description="User input")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict:
hist = [ModelMessage.build_human_message(input_value.user_input)]
# Serialize messages to plain dicts for the model request payload.
hist = [h.dict() for h in hist]
params = {
"prompt": input_value.user_input,
"messages": hist,
"model": input_value.model,
"echo": False,
}
print(f"Receive input value: {input_value}")
return params
with DAG("dbgpt_awel_simple_dag_example") as dag:
# Receive an HTTP request and trigger the DAG run.
trigger = HttpTrigger(
"/examples/simple_chat", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
model_task = ModelOperator()
# The output of model_task is a ModelOutput; convert it to a plain dict.
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task

View File

@ -0,0 +1,32 @@
"""AWEL: Simple dag example
Example:
.. code-block:: shell
curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan
"""
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
class TriggerReqBody(BaseModel):
name: str = Field(..., description="User name")
age: int = Field(18, description="User age")
class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> str:
print(f"Receive input value: {input_value}")
return f"Hello, {input_value.name}, your age is {input_value.age}"
with DAG("simple_dag_example") as dag:
trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
map_node = RequestHandleOperator()
trigger >> map_node
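Given the handler above, the response to the curl request in the docstring is fully determined by the code (age falls back to its declared default of 18):

"Hello, zhangsan, your age is 18"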

View File

@ -1,8 +1,17 @@
"""Agentic Workflow Expression Language (AWEL)"""
"""Agentic Workflow Expression Language (AWEL)
Note:
AWEL is still an experimental feature and currently exposes only the lowest-level API.
The stability of this API cannot be guaranteed at present.
"""
from pilot.component import SystemApp
from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
@ -28,6 +37,7 @@ from .task.task_impl import (
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
@ -57,4 +67,21 @@ __all__ = [
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .operator.base import initialize_runner
DAGVar.set_current_system_app(system_app)
system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()
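For orientation, a minimal sketch of how an application can bootstrap AWEL with this function. The bare SystemApp() construction and the directory path are illustrative assumptions; the actual server wiring appears further down in this commit (see initialize_components):

from pilot.component import SystemApp
from pilot.awel import initialize_awel

# Hypothetical bootstrap: the real server builds a SystemApp around its web
# app; constructed bare here purely for illustration.
system_app = SystemApp()
# Scan this directory for Python files containing module-level DAG objects.
initialize_awel(system_app, "examples/awel")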

pilot/awel/base.py Normal file
View File

@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class Trigger(ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -1,14 +1,19 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from pilot.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
@ -96,6 +101,7 @@ class DependencyMixin(ABC):
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: Optional[SystemApp] = None
@classmethod
def enter_dag(cls, dag) -> None:
@ -138,18 +144,38 @@ class DAGVar:
return cls._thread_local.current_dag_stack[-1]
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
if not cls._system_app:
raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
if cls._system_app:
logger.warning("System APP has already been set, nothing to do")
else:
cls._system_app = system_app
class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
self,
dag: Optional["DAG"] = None,
node_id: Optional[str] = None,
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
@ -159,6 +185,10 @@ class DAGNode(DependencyMixin, ABC):
def node_id(self) -> str:
return self._node_id
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
@ -178,7 +208,7 @@ class DAGNode(DependencyMixin, ABC):
return self._node_name
@property
def dag(self) -> "DAGNode":
def dag(self) -> "DAG":
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
@ -254,17 +284,69 @@ class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Optional[List[DAGNode]] = None
self._leaf_nodes: Optional[List[DAGNode]] = None
self._trigger_nodes: Optional[List[DAGNode]] = None
def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node
# Invalidate all cached node lists; they are rebuilt on next access.
self._root_nodes = None
self._leaf_nodes = None
self._trigger_nodes = None
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
self._trigger_nodes = list(
set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
)
@property
def root_nodes(self) -> List[DAGNode]:
if self._root_nodes is None:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
if self._leaf_nodes is None:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self) -> List[DAGNode]:
if self._trigger_nodes is None:
self._build()
return self._trigger_nodes
def __enter__(self):
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for stream_node in stream_nodes:
nodes = nodes.union(_get_nodes(stream_node, is_upstream))
return nodes
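To make the node classification concrete, a small sketch (it assumes DAGVar.set_current_system_app has already been called, since DAGNode construction resolves the current system app):

from pilot.awel import DAG, HttpTrigger, MapOperator

with DAG("classification_example") as dag:
    trigger = HttpTrigger("/example")  # no upstream: a root node, and a TriggerOperator
    task = MapOperator(lambda x: x)    # no downstream: a leaf node
    trigger >> task

assert dag.root_nodes == [trigger]
assert dag.leaf_nodes == [task]
assert dag.trigger_nodes == [trigger]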

View File

@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_filepath: str):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def load_dags(self):
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} already exists")
self.dag_map[dag_id] = dag
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager
trigger_manager: DefaultTriggerManager = self.system_app.get_component(
ComponentType.AWEL_TRIGGER_MANAGER,
DefaultTriggerManager,
default_component=None,
)
if trigger_manager:
for trigger in triggers:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
else:
logger.warning("No trigger manager available, DAG triggers will not be registered")

pilot/awel/dag/loader.py Normal file
View File

@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
@abstractmethod
def load_dags(self) -> List[DAG]:
"""Load dags"""
class LocalFileDAGLoader(DAGLoader):
def __init__(self, filepath: str) -> None:
super().__init__()
self._filepath = filepath
def load_dags(self) -> List[DAG]:
if not os.path.exists(self._filepath):
return []
if os.path.isdir(self._filepath):
return _process_directory(self._filepath)
else:
return _process_file(self._filepath)
def _process_directory(directory: str) -> List[DAG]:
dags = []
for file in os.listdir(directory):
if file.endswith(".py"):
filepath = os.path.join(directory, file)
dags += _process_file(filepath)
return dags
def _process_file(filepath) -> List[DAG]:
mods = _load_modules_from_file(filepath)
results = _process_modules(mods)
return results
def _load_modules_from_file(filepath: str):
import importlib
import importlib.machinery
import importlib.util
logger.info(f"Importing {filepath}")
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
# Prefix the module name with a hash of the full path so two DAG files with
# the same basename in different directories don't collide in sys.modules.
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules:
del sys.modules[mod_name]
def parse(mod_name, filepath):
try:
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
spec = importlib.util.spec_from_loader(mod_name, loader)
new_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message
return []
return parse(mod_name, filepath)
def _process_modules(mods) -> List[DAG]:
top_level_dags = (
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
)
found_dags = []
for dag, mod in top_level_dags:
try:
# TODO validate dag params
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
found_dags.append(dag)
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to dag file, error message: {msg}")
return found_dags

View File

@ -14,6 +14,7 @@ from typing import (
)
import functools
from inspect import signature
from pilot.component import SystemApp
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
@ -67,6 +68,9 @@ class BaseOperatorMeta(ABCMeta):
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id")
system_app: Optional[SystemApp] = (
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
@ -80,6 +84,8 @@ class BaseOperatorMeta(ABCMeta):
kwargs["task_id"] = task_id
if not kwargs.get("runner"):
kwargs["runner"] = runner
if not kwargs.get("system_app"):
kwargs["system_app"] = system_app
real_obj = func(self, *args, **kwargs)
return real_obj
@ -172,6 +178,6 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return out_ctx.current_task_context.task_output.output_stream
def initialize_awel(runner: WorkflowRunner):
def initialize_runner(runner: WorkflowRunner):
global default_runner
default_runner = runner

View File

@ -237,3 +237,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
def __init__(self, **kwargs) -> None:
from ..task.task_impl import SimpleCallDataInputSource
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)

View File

@ -3,7 +3,7 @@ import logging
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

View File

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput
class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -0,0 +1,117 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from pydantic import BaseModel
import logging
from .base import Trigger
from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
status_code: Optional[int] = 200,
router_tags: Optional[List[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
self._endpoint = endpoint
self._methods = methods
self._req_body = request_body
self._streaming_response = streaming_response
self._response_model = response_model
self._status_code = status_code
self._router_tags = router_tags
self._response_headers = response_headers
self._response_media_type = response_media_type
self._end_node: Optional[BaseOperator] = None
async def trigger(self) -> None:
pass
def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
from fastapi.responses import StreamingResponse
methods = self._methods if isinstance(self._methods, list) else [self._methods]
def create_route_function(name):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)
async def route_function(body: Any = Depends(_request_body_dependency)):
leaf_nodes = self.dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("HttpTrigger only supports a single leaf node in the DAG")
end_node = leaf_nodes[0]
if not self._streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = self._response_headers
media_type = (
self._response_media_type
if self._response_media_type
else "text/event-stream"
)
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
route_function.__name__ = name
return route_function
function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
dynamic_route_function = create_route_function(function_name)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
router.api_route(
self._endpoint,
methods=methods,
response_model=self._response_model,
status_code=self._status_code,
tags=self._router_tags,
)(dynamic_route_function)
async def _parse_request_body(
request: Request, request_body_cls: Optional[Type[BaseModel]]
):
if not request_body_cls:
return None
if request.method == "POST":
json_data = await request.json()
return request_body_cls(**json_data)
elif request.method == "GET":
return request_body_cls(**request.query_params)
else:
return request
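A sketch of the streaming path that streaming_response=True enables. Here sse_format_task is a hypothetical operator (e.g. built on TransformStreamAbsOperator), not part of this commit: StreamingResponse expects str/bytes chunks, so each ModelOutput must be converted before it reaches the response.

with DAG("streaming_example") as dag:
    trigger = HttpTrigger(
        "/examples/stream_chat",
        methods="POST",
        request_body=TriggerReqBody,  # the chat example's request model
        streaming_response=True,
    )
    request_handle_task = RequestHandleOperator()  # as in the chat example
    model_stream_task = ModelStreamOperator()
    # sse_format_task (hypothetical) yields strings like f"data: {json}\n\n".
    trigger >> request_handle_task >> model_stream_task >> sse_format_task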

View File

@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
if TYPE_CHECKING:
from fastapi import APIRouter
from pilot.component import SystemApp, BaseComponent, ComponentType
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
@abstractmethod
def register_trigger(self, trigger: Any) -> None:
""" "Register a trigger to current manager"""
class HttpTriggerManager(TriggerManager):
def __init__(
self,
router: Optional["APIRouter"] = None,
router_prefix: Optional[str] = "/api/v1/awel/trigger",
) -> None:
if not router:
from fastapi import APIRouter
router = APIRouter()
self._router_prefix = router_prefix
self._router = router
self._trigger_map = {}
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if not isinstance(trigger, HttpTrigger):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger: HttpTrigger = trigger
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
trigger.mount_to_router(self._router)
self._trigger_map[trigger_id] = trigger
def _init_app(self, system_app: SystemApp):
logger.info(
f"Include router {self._router} to prefix path {self._router_prefix}"
)
system_app.app.include_router(
self._router, prefix=self._router_prefix, tags=["AWEL"]
)
class DefaultTriggerManager(TriggerManager, BaseComponent):
name = ComponentType.AWEL_TRIGGER_MANAGER
def __init__(self, system_app: Optional[SystemApp] = None):
self.system_app = system_app
self.http_trigger = HttpTriggerManager()
super().__init__(None)
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if isinstance(trigger, HttpTrigger):
logger.info(f"Register trigger {trigger}")
self.http_trigger.register_trigger(trigger)
else:
raise ValueError(f"Unsupport trigger: {trigger}")
def after_register(self) -> None:
self.http_trigger._init_app(self.system_app)
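The mounted path is the router prefix plus the trigger endpoint, which is exactly the URL used by the curl examples at the top of this commit. A sketch of the manual flow (DAGManager.load_dags normally drives this; assumes a SystemApp exists and DAGVar.set_current_system_app was called):

manager = DefaultTriggerManager(system_app)
manager.register_trigger(HttpTrigger("/examples/hello"))
manager.after_register()
# Route is now served at /api/v1/awel/trigger/examples/hello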

View File

@ -54,6 +54,8 @@ class ComponentType(str, Enum):
TRACER = "dbgpt_tracer"
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
class BaseComponent(LifeCycle, ABC):

View File

@ -16,6 +16,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
current_directory = os.getcwd()

View File

@ -7,9 +7,10 @@ from pilot.awel import (
MapOperator,
TransformStreamAbsOperator,
)
from pilot.component import ComponentType
from pilot.awel.operator.base import BaseOperator
from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager
from pilot.model.cluster import WorkerManager, WorkerManagerFactory
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
logger = logging.getLogger(__name__)
@ -29,7 +30,7 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
streamify: Asynchronously processes a stream of inputs, yielding model outputs.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
@ -42,6 +43,10 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
Returns:
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
"""
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for out in self.worker_manager.generate_stream(input_value):
yield out
@ -57,9 +62,9 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
map: Asynchronously processes a single input and returns the model output.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
self.worker_manager = worker_manager
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
async def map(self, input_value: Dict) -> ModelOutput:
"""Process a single input and return the model output.
@ -70,6 +75,10 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
Returns:
ModelOutput: The output from the model.
"""
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return await self.worker_manager.generate(input_value)

View File

@ -117,6 +117,10 @@ class ModelMessage(BaseModel):
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
class Generation(BaseModel):
"""Output of a single generation."""

View File

@ -45,6 +45,7 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
def _initialize_embedding_model(
@ -149,3 +150,10 @@ def _initialize_model_cache(system_app: SystemApp):
max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
def _initialize_awel(system_app: SystemApp):
from pilot.awel import initialize_awel
from pilot.configs.model_config import _DAG_DEFINITION_DIR
initialize_awel(system_app, _DAG_DEFINITION_DIR)

View File

@ -1,9 +1,14 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
from typing import Any, List, Optional, Type, Union, Callable, Dict
from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
from collections import OrderedDict
if TYPE_CHECKING:
from pydantic import BaseModel
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
@dataclass
class ParameterDescription:
@ -613,6 +618,64 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
return default_value
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
import pydantic
version = int(pydantic.VERSION.split(".")[0])
schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
required_fields = set(schema.get("required", []))
param_descs = []
for field_name, field_schema in schema.get("properties", {}).items():
field = model_cls.model_fields[field_name] if version >= 2 else model_cls.__fields__[field_name]
param_type = field_schema.get("type")
if not param_type and "anyOf" in field_schema:
for any_of in field_schema["anyOf"]:
if any_of["type"] != "null":
param_type = any_of["type"]
break
if version >= 2:
default_value = (
field.default
if hasattr(field, "default")
and str(field.default) != "PydanticUndefined"
else None
)
else:
default_value = (
field.default
if not field.allow_none
else (
field.default_factory() if callable(field.default_factory) else None
)
)
description = field_schema.get("description", "")
is_required = field_name in required_fields
valid_values = None
ext_metadata = None
if hasattr(field, "field_info"):
valid_values = (
list(field.field_info.choices)
if hasattr(field.field_info, "choices")
else None
)
ext_metadata = (
field.field_info.extra if hasattr(field.field_info, "extra") else None
)
param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
param_desc = ParameterDescription(
param_class=param_class,
param_name=field_name,
param_type=param_type,
default_value=default_value,
description=description,
required=is_required,
valid_values=valid_values,
ext_metadata=ext_metadata,
)
param_descs.append(param_desc)
return param_descs
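For example, applied to the TriggerReqBody model from the chat example above, this would yield one ParameterDescription per field, roughly:

descs = _get_base_model_descriptions(TriggerReqBody)
# [ParameterDescription(param_name="model", param_type="string", required=True, description="Model name"),
#  ParameterDescription(param_name="user_input", param_type="string", required=True, description="User input")]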
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}