diff --git a/examples/awel/simple_chat_dag_example.py b/examples/awel/simple_chat_dag_example.py new file mode 100644 index 000000000..4da382d58 --- /dev/null +++ b/examples/awel/simple_chat_dag_example.py @@ -0,0 +1,54 @@ +"""AWEL: Simple chat dag example + + Example: + + .. code-block:: shell + + curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "user_input": "hello" + }' +""" +from typing import Dict +from pydantic import BaseModel, Field + +from pilot.awel import DAG, HttpTrigger, MapOperator +from pilot.scene.base_message import ModelMessage +from pilot.model.base import ModelOutput +from pilot.model.operator.model_operator import ModelOperator + + +class TriggerReqBody(BaseModel): + model: str = Field(..., description="Model name") + user_input: str = Field(..., description="User input") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + hist = [] + hist.append(ModelMessage.build_human_message(input_value.user_input)) + hist = list(h.dict() for h in hist) + params = { + "prompt": input_value.user_input, + "messages": hist, + "model": input_value.model, + "echo": False, + } + print(f"Receive input value: {input_value}") + return params + + +with DAG("dbgpt_awel_simple_dag_example") as dag: + # Receive http request and trigger dag to run. 
+ trigger = HttpTrigger( + "/examples/simple_chat", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + model_task = ModelOperator() + # type(out) == ModelOutput + model_parse_task = MapOperator(lambda out: out.to_dict()) + trigger >> request_handle_task >> model_task >> model_parse_task diff --git a/examples/awel/simple_dag_example.py b/examples/awel/simple_dag_example.py new file mode 100644 index 000000000..0bdf0dff7 --- /dev/null +++ b/examples/awel/simple_dag_example.py @@ -0,0 +1,32 @@ +"""AWEL: Simple dag example + + Example: + + .. code-block:: shell + + curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan + +""" +from pydantic import BaseModel, Field + +from pilot.awel import DAG, HttpTrigger, MapOperator + + +class TriggerReqBody(BaseModel): + name: str = Field(..., description="User name") + age: int = Field(18, description="User age") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, str]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> str: + print(f"Receive input value: {input_value}") + return f"Hello, {input_value.name}, your age is {input_value.age}" + + +with DAG("simple_dag_example") as dag: + trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody) + map_node = RequestHandleOperator() + trigger >> map_node diff --git a/examples/awel/simple_rag_example.py b/examples/awel/simple_rag_example.py new file mode 100644 index 000000000..78c08ac2f --- /dev/null +++ b/examples/awel/simple_rag_example.py @@ -0,0 +1,70 @@ +"""AWEL: Simple rag example + + Example: + + .. 
code-block:: shell + + curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \ + -H "Content-Type: application/json" -d '{ + "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003", + "model_name": "proxyllm", + "chat_mode": "chat_knowledge", + "user_input": "What is DB-GPT?", + "select_param": "default" + }' + +""" + +from pilot.awel import HttpTrigger, DAG, MapOperator +from pilot.scene.operator._experimental import ( + ChatContext, + PromptManagerOperator, + ChatHistoryStorageOperator, + ChatHistoryOperator, + EmbeddingEngingOperator, + BaseChatOperator, +) +from pilot.scene.base import ChatScene +from pilot.openapi.api_view_model import ConversationVo +from pilot.model.base import ModelOutput +from pilot.model.operator.model_operator import ModelOperator + + +class RequestParseOperator(MapOperator[ConversationVo, ChatContext]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: ConversationVo) -> ChatContext: + return ChatContext( + current_user_input=input_value.user_input, + model_name=input_value.model_name, + chat_session_id=input_value.conv_uid, + select_param=input_value.select_param, + chat_scene=ChatScene.ChatKnowledge, + ) + + +with DAG("simple_rag_example") as dag: + trigger_task = HttpTrigger( + "/examples/simple_rag", methods="POST", request_body=ConversationVo + ) + req_parse_task = RequestParseOperator() + prompt_task = PromptManagerOperator() + history_storage_task = ChatHistoryStorageOperator() + history_task = ChatHistoryOperator() + embedding_task = EmbeddingEngingOperator() + chat_task = BaseChatOperator() + model_task = ModelOperator() + output_parser_task = MapOperator(lambda out: out.to_dict()["text"]) + + ( + trigger_task + >> req_parse_task + >> prompt_task + >> history_storage_task + >> history_task + >> embedding_task + >> chat_task + >> model_task + >> output_parser_task + ) diff --git a/pilot/awel/__init__.py b/pilot/awel/__init__.py index 6c5313b5d..3cfc3c2bc 100644 --- 
a/pilot/awel/__init__.py +++ b/pilot/awel/__init__.py @@ -1,8 +1,17 @@ -"""Agentic Workflow Expression Language (AWEL)""" +"""Agentic Workflow Expression Language (AWEL) + +Note: + +AWEL is still an experimental feature and only opens the lowest level API. +The stability of this API cannot be guaranteed at present. + +""" + +from pilot.component import SystemApp from .dag.base import DAGContext, DAG -from .operator.base import BaseOperator, WorkflowRunner, initialize_awel +from .operator.base import BaseOperator, WorkflowRunner from .operator.common_operator import ( JoinOperator, ReduceStreamOperator, @@ -28,6 +37,7 @@ from .task.task_impl import ( SimpleStreamTaskOutput, _is_async_iterator, ) +from .trigger.http_trigger import HttpTrigger from .runner.local_runner import DefaultWorkflowRunner __all__ = [ @@ -57,4 +67,21 @@ __all__ = [ "StreamifyAbsOperator", "UnstreamifyAbsOperator", "TransformStreamAbsOperator", + "HttpTrigger", ] + + +def initialize_awel(system_app: SystemApp, dag_filepath: str): + from .dag.dag_manager import DAGManager + from .dag.base import DAGVar + from .trigger.trigger_manager import DefaultTriggerManager + from .operator.base import initialize_runner + + DAGVar.set_current_system_app(system_app) + + system_app.register(DefaultTriggerManager) + dag_manager = DAGManager(system_app, dag_filepath) + system_app.register_instance(dag_manager) + initialize_runner(DefaultWorkflowRunner()) + # Load all dags + dag_manager.load_dags() diff --git a/pilot/awel/base.py b/pilot/awel/base.py new file mode 100644 index 000000000..97cb8ad05 --- /dev/null +++ b/pilot/awel/base.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class Trigger(ABC): + @abstractmethod + async def trigger(self) -> None: + """Trigger the workflow or a specific operation in the workflow.""" diff --git a/pilot/awel/dag/base.py b/pilot/awel/dag/base.py index a6ad08990..ceb13c8ad 100644 --- a/pilot/awel/dag/base.py +++ b/pilot/awel/dag/base.py @@ -1,14 +1,20 @@ from abc 
import ABC, abstractmethod -from typing import Optional, Dict, List, Sequence, Union, Any +from typing import Optional, Dict, List, Sequence, Union, Any, Set import uuid import contextvars import threading import asyncio +import logging from collections import deque +from functools import cache +from concurrent.futures import Executor +from pilot.component import SystemApp from ..resource.base import ResourceGroup from ..task.base import TaskContext +logger = logging.getLogger(__name__) + DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]] @@ -96,6 +102,8 @@ class DependencyMixin(ABC): class DAGVar: _thread_local = threading.local() _async_local = contextvars.ContextVar("current_dag_stack", default=deque()) + _system_app: SystemApp = None + _executor: Executor = None @classmethod def enter_dag(cls, dag) -> None: @@ -138,18 +146,48 @@ class DAGVar: return cls._thread_local.current_dag_stack[-1] return None + @classmethod + def get_current_system_app(cls) -> SystemApp: + if not cls._system_app: + raise RuntimeError("System APP not set for DAGVar") + return cls._system_app + + @classmethod + def set_current_system_app(cls, system_app: SystemApp) -> None: + if cls._system_app: + logger.warn("System APP has already set, nothing to do") + else: + cls._system_app = system_app + + @classmethod + def get_executor(cls) -> Executor: + return cls._executor + + @classmethod + def set_executor(cls, executor: Executor) -> None: + cls._executor = executor + class DAGNode(DependencyMixin, ABC): resource_group: Optional[ResourceGroup] = None """The resource group of current DAGNode""" def __init__( - self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None + self, + dag: Optional["DAG"] = None, + node_id: Optional[str] = None, + node_name: Optional[str] = None, + system_app: Optional[SystemApp] = None, + executor: Optional[Executor] = None, ) -> None: super().__init__() self._upstream: List["DAGNode"] = [] self._downstream: List["DAGNode"] = 
[] self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag() + self._system_app: Optional[SystemApp] = ( + system_app or DAGVar.get_current_system_app() + ) + self._executor: Optional[Executor] = executor or DAGVar.get_executor() if not node_id and self._dag: node_id = self._dag._new_node_id() self._node_id: str = node_id @@ -159,6 +197,10 @@ class DAGNode(DependencyMixin, ABC): def node_id(self) -> str: return self._node_id + @property + def system_app(self) -> SystemApp: + return self._system_app + def set_node_id(self, node_id: str) -> None: self._node_id = node_id @@ -178,7 +220,7 @@ class DAGNode(DependencyMixin, ABC): return self._node_name @property - def dag(self) -> "DAGNode": + def dag(self) -> "DAG": return self._dag def set_upstream(self, nodes: DependencyType) -> "DAGNode": @@ -254,17 +296,69 @@ class DAG: def __init__( self, dag_id: str, resource_group: Optional[ResourceGroup] = None ) -> None: + self._dag_id = dag_id self.node_map: Dict[str, DAGNode] = {} + self._root_nodes: Set[DAGNode] = None + self._leaf_nodes: Set[DAGNode] = None + self._trigger_nodes: Set[DAGNode] = None def _append_node(self, node: DAGNode) -> None: self.node_map[node.node_id] = node + # clear cached nodes + self._root_nodes = None + self._leaf_nodes = None def _new_node_id(self) -> str: return str(uuid.uuid4()) + @property + def dag_id(self) -> str: + return self._dag_id + + def _build(self) -> None: + from ..operator.common_operator import TriggerOperator + + nodes = set() + for _, node in self.node_map.items(): + nodes = nodes.union(_get_nodes(node)) + self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes))) + self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes))) + self._trigger_nodes = list( + set(filter(lambda x: isinstance(x, TriggerOperator), nodes)) + ) + + @property + def root_nodes(self) -> List[DAGNode]: + if not self._root_nodes: + self._build() + return self._root_nodes + + @property + def leaf_nodes(self) -> List[DAGNode]: + if 
not self._leaf_nodes: + self._build() + return self._leaf_nodes + + @property + def trigger_nodes(self): + if not self._trigger_nodes: + self._build() + return self._trigger_nodes + def __enter__(self): DAGVar.enter_dag(self) return self def __exit__(self, exc_type, exc_val, exc_tb): DAGVar.exit_dag() + + +def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]: + nodes = set() + if not node: + return nodes + nodes.add(node) + stream_nodes = node.upstream if is_upstream else node.downstream + for node in stream_nodes: + nodes = nodes.union(_get_nodes(node, is_upstream)) + return nodes diff --git a/pilot/awel/dag/dag_manager.py b/pilot/awel/dag/dag_manager.py new file mode 100644 index 000000000..58830e121 --- /dev/null +++ b/pilot/awel/dag/dag_manager.py @@ -0,0 +1,42 @@ +from typing import Dict, Optional +import logging +from pilot.component import BaseComponent, ComponentType, SystemApp +from .loader import DAGLoader, LocalFileDAGLoader +from .base import DAG + +logger = logging.getLogger(__name__) + + +class DAGManager(BaseComponent): + name = ComponentType.AWEL_DAG_MANAGER + + def __init__(self, system_app: SystemApp, dag_filepath: str): + super().__init__(system_app) + self.dag_loader = LocalFileDAGLoader(dag_filepath) + self.system_app = system_app + self.dag_map: Dict[str, DAG] = {} + + def init_app(self, system_app: SystemApp): + self.system_app = system_app + + def load_dags(self): + dags = self.dag_loader.load_dags() + triggers = [] + for dag in dags: + dag_id = dag.dag_id + if dag_id in self.dag_map: + raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist") + triggers += dag.trigger_nodes + from ..trigger.trigger_manager import DefaultTriggerManager + + trigger_manager: DefaultTriggerManager = self.system_app.get_component( + ComponentType.AWEL_TRIGGER_MANAGER, + DefaultTriggerManager, + default_component=None, + ) + if trigger_manager: + for trigger in triggers: + trigger_manager.register_trigger(trigger) + 
trigger_manager.after_register() + else: + logger.warn("No trigger manager, not register dag trigger") diff --git a/pilot/awel/dag/loader.py b/pilot/awel/dag/loader.py new file mode 100644 index 000000000..2eb89f8bc --- /dev/null +++ b/pilot/awel/dag/loader.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from typing import List +import os +import hashlib +import sys +import logging +import traceback + +from .base import DAG + +logger = logging.getLogger(__name__) + + +class DAGLoader(ABC): + @abstractmethod + def load_dags(self) -> List[DAG]: + """Load dags""" + + +class LocalFileDAGLoader(DAGLoader): + def __init__(self, filepath: str) -> None: + super().__init__() + self._filepath = filepath + + def load_dags(self) -> List[DAG]: + if not os.path.exists(self._filepath): + return [] + if os.path.isdir(self._filepath): + return _process_directory(self._filepath) + else: + return _process_file(self._filepath) + + +def _process_directory(directory: str) -> List[DAG]: + dags = [] + for file in os.listdir(directory): + if file.endswith(".py"): + filepath = os.path.join(directory, file) + dags += _process_file(filepath) + return dags + + +def _process_file(filepath) -> List[DAG]: + mods = _load_modules_from_file(filepath) + results = _process_modules(mods) + return results + + +def _load_modules_from_file(filepath: str): + import importlib + import importlib.machinery + import importlib.util + + logger.info(f"Importing {filepath}") + + org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) + path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest() + mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}" + + if mod_name in sys.modules: + del sys.modules[mod_name] + + def parse(mod_name, filepath): + try: + loader = importlib.machinery.SourceFileLoader(mod_name, filepath) + spec = importlib.util.spec_from_loader(mod_name, loader) + new_module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = new_module + 
loader.exec_module(new_module) + return [new_module] + except Exception as e: + msg = traceback.format_exc() + logger.error(f"Failed to import: {filepath}, error message: {msg}") + # TODO save error message + return [] + + return parse(mod_name, filepath) + + +def _process_modules(mods) -> List[DAG]: + top_level_dags = ( + (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG) + ) + found_dags = [] + for dag, mod in top_level_dags: + try: + # TODO validate dag params + logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}") + found_dags.append(dag) + except Exception: + msg = traceback.format_exc() + logger.error(f"Failed to dag file, error message: {msg}") + return found_dags diff --git a/pilot/awel/operator/base.py b/pilot/awel/operator/base.py index b6d1a4e14..09aa87141 100644 --- a/pilot/awel/operator/base.py +++ b/pilot/awel/operator/base.py @@ -14,6 +14,13 @@ from typing import ( ) import functools from inspect import signature +from pilot.component import SystemApp, ComponentType +from pilot.utils.executor_utils import ( + ExecutorFactory, + DefaultExecutorFactory, + blocking_func_to_async, + BlockingFunction, +) from ..dag.base import DAGNode, DAGContext, DAGVar, DAG from ..task.base import ( @@ -67,6 +74,19 @@ class BaseOperatorMeta(ABCMeta): def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag() task_id: Optional[str] = kwargs.get("task_id") + system_app: Optional[SystemApp] = ( + kwargs.get("system_app") or DAGVar.get_current_system_app() + ) + executor = kwargs.get("executor") or DAGVar.get_executor() + if not executor: + if system_app: + executor = system_app.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() + else: + executor = DefaultExecutorFactory().create() + DAGVar.set_executor(executor) + if not task_id and dag: task_id = dag._new_node_id() runner: Optional[WorkflowRunner] = 
kwargs.get("runner") or default_runner @@ -80,6 +100,10 @@ class BaseOperatorMeta(ABCMeta): kwargs["task_id"] = task_id if not kwargs.get("runner"): kwargs["runner"] = runner + if not kwargs.get("system_app"): + kwargs["system_app"] = system_app + if not kwargs.get("executor"): + kwargs["executor"] = executor real_obj = func(self, *args, **kwargs) return real_obj @@ -171,7 +195,12 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): out_ctx = await self._runner.execute_workflow(self, call_data) return out_ctx.current_task_context.task_output.output_stream + async def blocking_func_to_async( + self, func: BlockingFunction, *args, **kwargs + ) -> Any: + return await blocking_func_to_async(self._executor, func, *args, **kwargs) -def initialize_awel(runner: WorkflowRunner): + +def initialize_runner(runner: WorkflowRunner): global default_runner default_runner = runner diff --git a/pilot/awel/operator/common_operator.py b/pilot/awel/operator/common_operator.py index 6d12565aa..2c0d41dde 100644 --- a/pilot/awel/operator/common_operator.py +++ b/pilot/awel/operator/common_operator.py @@ -237,3 +237,10 @@ class InputOperator(BaseOperator, Generic[OUT]): task_output = await self._input_source.read(curr_task_ctx) curr_task_ctx.set_task_output(task_output) return task_output + + +class TriggerOperator(InputOperator, Generic[OUT]): + def __init__(self, **kwargs) -> None: + from ..task.task_impl import SimpleCallDataInputSource + + super().__init__(input_source=SimpleCallDataInputSource(), **kwargs) diff --git a/pilot/awel/runner/local_runner.py b/pilot/awel/runner/local_runner.py index 769223212..6f8a0a484 100644 --- a/pilot/awel/runner/local_runner.py +++ b/pilot/awel/runner/local_runner.py @@ -3,7 +3,7 @@ import logging from ..dag.base import DAGContext from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA -from ..operator.common_operator import BranchOperator, JoinOperator +from ..operator.common_operator import BranchOperator, 
JoinOperator, TriggerOperator from ..task.base import TaskContext, TaskState from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput from .job_manager import JobManager @@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner): node_outputs[node.node_id] = task_ctx return try: - logger.info( + logger.debug( f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}" ) await node._run(dag_ctx) @@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner): if isinstance(node, BranchOperator): skip_nodes = task_ctx.metadata.get("skip_node_names", []) - logger.info( + logger.debug( f"Current is branch operator, skip node names: {skip_nodes}" ) _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids) diff --git a/pilot/server/componet_configs.py b/pilot/awel/trigger/__init__.py similarity index 100% rename from pilot/server/componet_configs.py rename to pilot/awel/trigger/__init__.py diff --git a/pilot/awel/trigger/base.py b/pilot/awel/trigger/base.py new file mode 100644 index 000000000..9cb5d1895 --- /dev/null +++ b/pilot/awel/trigger/base.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from abc import ABC, abstractmethod + +from ..operator.base import BaseOperator +from ..operator.common_operator import TriggerOperator +from ..dag.base import DAGContext +from ..task.base import TaskOutput + + +class Trigger(TriggerOperator, ABC): + @abstractmethod + async def trigger(self, end_operator: "BaseOperator") -> None: + """Trigger the workflow or a specific operation in the workflow.""" diff --git a/pilot/awel/trigger/http_trigger.py b/pilot/awel/trigger/http_trigger.py new file mode 100644 index 000000000..de459c066 --- /dev/null +++ b/pilot/awel/trigger/http_trigger.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict +from starlette.requests import Request +from 
starlette.responses import Response +from pydantic import BaseModel +import logging + +from .base import Trigger +from ..operator.base import BaseOperator + +if TYPE_CHECKING: + from fastapi import APIRouter, FastAPI + +RequestBody = Union[Request, Type[BaseModel], str] + +logger = logging.getLogger(__name__) + + +class HttpTrigger(Trigger): + def __init__( + self, + endpoint: str, + methods: Optional[Union[str, List[str]]] = "GET", + request_body: Optional[RequestBody] = None, + streaming_response: Optional[bool] = False, + response_model: Optional[Type] = None, + response_headers: Optional[Dict[str, str]] = None, + response_media_type: Optional[str] = None, + status_code: Optional[int] = 200, + router_tags: Optional[List[str]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + self._endpoint = endpoint + self._methods = methods + self._req_body = request_body + self._streaming_response = streaming_response + self._response_model = response_model + self._status_code = status_code + self._router_tags = router_tags + self._response_headers = response_headers + self._response_media_type = response_media_type + self._end_node: BaseOperator = None + + async def trigger(self) -> None: + pass + + def mount_to_router(self, router: "APIRouter") -> None: + from fastapi import Depends + from fastapi.responses import StreamingResponse + + methods = self._methods if isinstance(self._methods, list) else [self._methods] + + def create_route_function(name): + async def _request_body_dependency(request: Request): + return await _parse_request_body(request, self._req_body) + + async def route_function(body: Any = Depends(_request_body_dependency)): + end_node = self.dag.leaf_nodes + if len(end_node) != 1: + raise ValueError("HttpTrigger just support one leaf node in dag") + end_node = end_node[0] + if not self._streaming_response: + return await end_node.call(call_data={"data": body}) + else: + headers = 
self._response_headers + media_type = ( + self._response_media_type + if self._response_media_type + else "text/event-stream" + ) + if not headers: + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + return StreamingResponse( + end_node.call_stream(call_data={"data": body}), + headers=headers, + media_type=media_type, + ) + + route_function.__name__ = name + return route_function + + function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}" + dynamic_route_function = create_route_function(function_name) + logger.info( + f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}" + ) + + router.api_route( + self._endpoint, + methods=methods, + response_model=self._response_model, + status_code=self._status_code, + tags=self._router_tags, + )(dynamic_route_function) + + +async def _parse_request_body( + request: Request, request_body_cls: Optional[Type[BaseModel]] +): + if not request_body_cls: + return None + if request.method == "POST": + json_data = await request.json() + return request_body_cls(**json_data) + elif request.method == "GET": + return request_body_cls(**request.query_params) + else: + return request diff --git a/pilot/awel/trigger/trigger_manager.py b/pilot/awel/trigger/trigger_manager.py new file mode 100644 index 000000000..feb674ffb --- /dev/null +++ b/pilot/awel/trigger/trigger_manager.py @@ -0,0 +1,74 @@ +from abc import ABC, abstractmethod +from typing import Any, TYPE_CHECKING, Optional +import logging + +if TYPE_CHECKING: + from fastapi import APIRouter + +from pilot.component import SystemApp, BaseComponent, ComponentType + +logger = logging.getLogger(__name__) + + +class TriggerManager(ABC): + @abstractmethod + def register_trigger(self, trigger: Any) -> None: + """ "Register a trigger to current manager""" + + +class HttpTriggerManager(TriggerManager): + def __init__( + self, 
+ router: Optional["APIRouter"] = None, + router_prefix: Optional[str] = "/api/v1/awel/trigger", + ) -> None: + if not router: + from fastapi import APIRouter + + router = APIRouter() + self._router_prefix = router_prefix + self._router = router + self._trigger_map = {} + + def register_trigger(self, trigger: Any) -> None: + from .http_trigger import HttpTrigger + + if not isinstance(trigger, HttpTrigger): + raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") + trigger: HttpTrigger = trigger + trigger_id = trigger.node_id + if trigger_id not in self._trigger_map: + trigger.mount_to_router(self._router) + self._trigger_map[trigger_id] = trigger + + def _init_app(self, system_app: SystemApp): + logger.info( + f"Include router {self._router} to prefix path {self._router_prefix}" + ) + system_app.app.include_router( + self._router, prefix=self._router_prefix, tags=["AWEL"] + ) + + +class DefaultTriggerManager(TriggerManager, BaseComponent): + name = ComponentType.AWEL_TRIGGER_MANAGER + + def __init__(self, system_app: SystemApp | None = None): + self.system_app = system_app + self.http_trigger = HttpTriggerManager() + super().__init__(None) + + def init_app(self, system_app: SystemApp): + self.system_app = system_app + + def register_trigger(self, trigger: Any) -> None: + from .http_trigger import HttpTrigger + + if isinstance(trigger, HttpTrigger): + logger.info(f"Register trigger {trigger}") + self.http_trigger.register_trigger(trigger) + else: + raise ValueError(f"Unsupport trigger: {trigger}") + + def after_register(self) -> None: + self.http_trigger._init_app(self.system_app) diff --git a/pilot/component.py b/pilot/component.py index d79a8d395..891ba7ad9 100644 --- a/pilot/component.py +++ b/pilot/component.py @@ -54,6 +54,8 @@ class ComponentType(str, Enum): TRACER = "dbgpt_tracer" TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage" RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default" + AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager" + 
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager" class BaseComponent(LifeCycle, ABC): diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index fec343f2a..356abb644 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -16,6 +16,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") FONT_DIR = os.path.join(PILOT_PATH, "fonts") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") +_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel") current_directory = os.getcwd() diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py index 8afbaf06b..f1c25d633 100644 --- a/pilot/memory/chat_history/store_type/meta_db_history.py +++ b/pilot/memory/chat_history/store_type/meta_db_history.py @@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory): logger.error("init create conversation log error!" + str(e)) def append(self, once_message: OnceConversation) -> None: - logger.info(f"db history append: {once_message}") + logger.debug(f"db history append: {once_message}") chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( self.chat_seesion_id ) diff --git a/pilot/model/operator/model_operator.py b/pilot/model/operator/model_operator.py index 2f051377a..d8ee62172 100644 --- a/pilot/model/operator/model_operator.py +++ b/pilot/model/operator/model_operator.py @@ -7,9 +7,10 @@ from pilot.awel import ( MapOperator, TransformStreamAbsOperator, ) +from pilot.component import ComponentType from pilot.awel.operator.base import BaseOperator from pilot.model.base import ModelOutput -from pilot.model.cluster import WorkerManager +from pilot.model.cluster import WorkerManager, WorkerManagerFactory from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): 
streamify: Asynchronously processes a stream of inputs, yielding model outputs. """ - def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: + def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None: super().__init__(**kwargs) self.worker_manager = worker_manager @@ -42,6 +43,10 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): Returns: AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. """ + if not self.worker_manager: + self.worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() async for out in self.worker_manager.generate_stream(input_value): yield out @@ -57,9 +62,9 @@ class ModelOperator(MapOperator[Dict, ModelOutput]): map: Asynchronously processes a single input and returns the model output. """ - def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: - self.worker_manager = worker_manager + def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None: super().__init__(**kwargs) + self.worker_manager = worker_manager async def map(self, input_value: Dict) -> ModelOutput: """Process a single input and return the model output. @@ -70,6 +75,10 @@ class ModelOperator(MapOperator[Dict, ModelOutput]): Returns: ModelOutput: The output from the model. 
""" + if not self.worker_manager: + self.worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() return await self.worker_manager.generate(input_value) diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py index 9e6d1a20a..1da815bfa 100644 --- a/pilot/model/proxy/llms/chatgpt.py +++ b/pilot/model/proxy/llms/chatgpt.py @@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params): proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" payloads["model"] = proxyllm_backend - logger.info( - f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}" - ) + logger.info(f"Send request to real model {proxyllm_backend}") return history, payloads diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index e5ec0e8b8..0e263f7e5 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -68,7 +68,7 @@ class BaseChat(ABC): CFG.prompt_template_registry.get_prompt_template( self.chat_mode.value(), language=CFG.LANGUAGE, - model_name=CFG.LLM_MODEL, + model_name=self.llm_model, proxyllm_backend=CFG.PROXYLLM_BACKEND, ) ) @@ -141,13 +141,7 @@ class BaseChat(ABC): return speak_to_user async def __call_base(self): - import inspect - - input_values = ( - await self.generate_input_values() - if inspect.isawaitable(self.generate_input_values()) - else self.generate_input_values() - ) + input_values = await self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) @@ -379,16 +373,18 @@ class BaseChat(ABC): if self.prompt_template.template_define: text += self.prompt_template.template_define + self.prompt_template.sep ### Load prompt - text += self.__load_system_message() + text += _load_system_message(self.current_message, self.prompt_template) ### Load examples - text += self.__load_example_messages() + text += 
_load_example_messages(self.prompt_template) ### Load History - text += self.__load_history_messages() + text += _load_history_messages( + self.prompt_template, self.history_message, self.chat_retention_rounds + ) ### Load User Input - text += self.__load_user_message() + text += _load_user_message(self.current_message, self.prompt_template) return text def generate_llm_messages(self) -> List[ModelMessage]: @@ -406,137 +402,26 @@ class BaseChat(ABC): ) ) ### Load prompt - messages += self.__load_system_message(str_message=False) + messages += _load_system_message( + self.current_message, self.prompt_template, str_message=False + ) ### Load examples - messages += self.__load_example_messages(str_message=False) + messages += _load_example_messages(self.prompt_template, str_message=False) ### Load History - messages += self.__load_history_messages(str_message=False) + messages += _load_history_messages( + self.prompt_template, + self.history_message, + self.chat_retention_rounds, + str_message=False, + ) ### Load User Input - messages += self.__load_user_message(str_message=False) + messages += _load_user_message( + self.current_message, self.prompt_template, str_message=False + ) return messages - def __load_system_message(self, str_message: bool = True): - system_convs = self.current_message.get_system_conv() - system_text = "" - system_messages = [] - for system_conv in system_convs: - system_text += ( - system_conv.type + ":" + system_conv.content + self.prompt_template.sep - ) - system_messages.append( - ModelMessage(role=system_conv.type, content=system_conv.content) - ) - return system_text if str_message else system_messages - - def __load_user_message(self, str_message: bool = True): - user_conv = self.current_message.get_user_conv() - user_messages = [] - if user_conv: - user_text = ( - user_conv.type + ":" + user_conv.content + self.prompt_template.sep - ) - user_messages.append( - ModelMessage(role=user_conv.type, content=user_conv.content) - ) - return 
user_text if str_message else user_messages - else: - raise ValueError("Hi! What do you want to talk about?") - - def __load_example_messages(self, str_message: bool = True): - example_text = "" - example_messages = [] - if self.prompt_template.example_selector: - for round_conv in self.prompt_template.example_selector.examples(): - for round_message in round_conv["messages"]: - if not round_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = round_message["type"] - message_content = round_message["data"]["content"] - example_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep - ) - example_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - return example_text if str_message else example_messages - - def __load_history_messages(self, str_message: bool = True): - history_text = "" - history_messages = [] - if self.prompt_template.need_historical_messages: - if self.history_message: - logger.info( - f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" 
- ) - if len(self.history_message) > self.chat_retention_rounds: - for first_message in self.history_message[0]["messages"]: - if not first_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = first_message["type"] - message_content = first_message["data"]["content"] - history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep - ) - history_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - if self.chat_retention_rounds > 1: - index = self.chat_retention_rounds - 1 - for round_conv in self.history_message[-index:]: - for round_message in round_conv["messages"]: - if not round_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = round_message["type"] - message_content = round_message["data"]["content"] - history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep - ) - history_messages.append( - ModelMessage( - role=message_type, content=message_content - ) - ) - - else: - ### user all history - for conversation in self.history_message: - for message in conversation["messages"]: - ### histroy message not have promot and view info - if not message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = message["type"] - message_content = message["data"]["content"] - history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep - ) - history_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - - return history_text if str_message else history_messages - def current_ai_response(self) -> str: for message in self.current_message.messages: if message.type == "view": @@ -656,3 +541,127 @@ def _build_model_operator( cache_check_branch_node >> cached_node >> join_node return join_node + + +def _load_system_message( + current_message: OnceConversation, + prompt_template: PromptTemplate, + 
def _load_history_messages(
    prompt_template: PromptTemplate,
    history_message: List[OnceConversation],
    chat_retention_rounds: int,
    str_message: bool = True,
):
    """Collect the historical rounds that should be replayed to the model.

    Nothing is loaded unless ``prompt_template.need_historical_messages`` is
    truthy. When there are more stored rounds than ``chat_retention_rounds``,
    the first round is always kept and, if ``chat_retention_rounds > 1``, the
    last ``chat_retention_rounds - 1`` rounds are appended after it. Otherwise
    every stored round is used. Messages whose type is ``VIEW`` or ``SYSTEM``
    are skipped.

    Args:
        prompt_template: Supplies ``need_historical_messages`` and the ``sep``
            appended after every message in the text form.
        history_message: Stored rounds. Each round is indexed like a mapping
            (``round["messages"]``, ``message["type"]``,
            ``message["data"]["content"]``).
            NOTE(review): the annotation says ``OnceConversation`` but the
            code subscripts the items -- confirm the real element type.
        chat_retention_rounds: Number of rounds to retain.
        str_message: Return one concatenated string when True, otherwise a
            list of ``ModelMessage``.

    Returns:
        ``"type:content<sep>"`` fragments joined into a ``str`` when
        ``str_message`` is True, else a ``List[ModelMessage]``.
    """
    if not prompt_template.need_historical_messages:
        return "" if str_message else []

    if history_message:
        logger.info(
            f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
        )

    # Decide once which rounds are replayed, then format them in one loop
    # instead of repeating the same loop body for every case.
    if len(history_message) > chat_retention_rounds:
        selected_rounds = [history_message[0]]
        if chat_retention_rounds > 1:
            tail_size = chat_retention_rounds - 1
            selected_rounds.extend(history_message[-tail_size:])
    else:
        # Not more rounds than the retention limit: use all history.
        selected_rounds = history_message

    # History must not carry prompt (system) or view messages.
    hidden_roles = (ModelMessageRoleType.VIEW, ModelMessageRoleType.SYSTEM)
    text_parts = []
    history_messages = []
    for round_conv in selected_rounds:
        for round_message in round_conv["messages"]:
            message_type = round_message["type"]
            if message_type in hidden_roles:
                continue
            message_content = round_message["data"]["content"]
            text_parts.append(
                message_type + ":" + message_content + prompt_template.sep
            )
            history_messages.append(
                ModelMessage(role=message_type, content=message_content)
            )

    return "".join(text_parts) if str_message else history_messages
messages)) + @staticmethod + def build_human_message(content: str) -> "ModelMessage": + return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) + class Generation(BaseModel): """Output of a single generation.""" diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index 6aa1d3d91..00cb27a2b 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -6,7 +6,6 @@ import re import sqlparse import pandas as pd import chardet -import pandas as pd import numpy as np from pyparsing import ( CaselessKeyword, @@ -27,6 +26,8 @@ from pyparsing import ( from pilot.common.pd_utils import csv_colunm_foramt from pilot.common.string_utils import is_chinese_include_number +logger = logging.getLogger(__name__) + def excel_colunm_format(old_name: str) -> str: new_column = old_name.strip() @@ -263,7 +264,7 @@ class ExcelReader: file_name = os.path.basename(file_path) self.file_name_without_extension = os.path.splitext(file_name)[0] encoding, confidence = detect_encoding(file_path) - logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})") + logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})") self.excel_file_name = file_name self.extension = os.path.splitext(file_name)[1] # read excel file @@ -323,7 +324,7 @@ class ExcelReader: colunms.append(descrip[0]) return colunms, results.fetchall() except Exception as e: - logging.error("excel sql run error!", e) + logger.error(f"excel sql run error!, {str(e)}") raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}") def get_df_by_sql_ex(self, sql): diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index 1cd5765da..bd1dd9de8 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser): 
def parse_prompt_response(self, model_out_text): clean_str = super().parse_prompt_response(model_out_text) - logging.info("clean prompt response:", clean_str) + logger.info(f"clean prompt response: {clean_str}") # Compatible with community pure sql output model if self.is_sql_statement(clean_str): return SqlAction(clean_str, "") @@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser): thoughts = response[key] return SqlAction(sql, thoughts) except Exception as e: - logging.error("json load faild") + logger.error("json load faild") return SqlAction("", clean_str) def parse_view_response(self, speak, data, prompt_response) -> str: diff --git a/pilot/scene/chat_knowledge/extract_entity/chat.py b/pilot/scene/chat_knowledge/extract_entity/chat.py index bb52961b5..373bb4e5d 100644 --- a/pilot/scene/chat_knowledge/extract_entity/chat.py +++ b/pilot/scene/chat_knowledge/extract_entity/chat.py @@ -24,7 +24,7 @@ class ExtractEntity(BaseChat): self.user_input = chat_param["current_user_input"] self.extract_mode = chat_param["select_param"] - def generate_input_values(self): + async def generate_input_values(self): input_values = { "text": self.user_input, } diff --git a/pilot/scene/chat_knowledge/extract_triplet/chat.py b/pilot/scene/chat_knowledge/extract_triplet/chat.py index 11fe871ab..28152b92e 100644 --- a/pilot/scene/chat_knowledge/extract_triplet/chat.py +++ b/pilot/scene/chat_knowledge/extract_triplet/chat.py @@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat): self.user_input = chat_param["current_user_input"] self.extract_mode = chat_param["select_param"] - def generate_input_values(self): + async def generate_input_values(self): input_values = { "text": self.user_input, } diff --git a/pilot/scene/chat_knowledge/refine_summary/chat.py b/pilot/scene/chat_knowledge/refine_summary/chat.py index a257332ae..2f3181d5e 100644 --- a/pilot/scene/chat_knowledge/refine_summary/chat.py +++ b/pilot/scene/chat_knowledge/refine_summary/chat.py @@ -23,7 +23,7 @@ class 
ExtractRefineSummary(BaseChat): self.existing_answer = chat_param["select_param"] - def generate_input_values(self): + async def generate_input_values(self): input_values = { # "context": self.user_input, "existing_answer": self.existing_answer, diff --git a/pilot/scene/chat_knowledge/summary/chat.py b/pilot/scene/chat_knowledge/summary/chat.py index 7327b7a5b..be4ee00c3 100644 --- a/pilot/scene/chat_knowledge/summary/chat.py +++ b/pilot/scene/chat_knowledge/summary/chat.py @@ -23,7 +23,7 @@ class ExtractSummary(BaseChat): self.user_input = chat_param["select_param"] - def generate_input_values(self): + async def generate_input_values(self): input_values = { "context": self.user_input, } diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 551fe2d36..a0c15e658 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat): self.current_user_input, self.top_k, ) - self.sources = self.merge_by_key( + self.sources = _merge_by_key( list(map(lambda doc: doc.metadata, docs)), "source" ) @@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat): ) return html - def merge_by_key(self, data, key): - result = {} - for item in data: - if item.get(key): - item_key = os.path.basename(item.get(key)) - if item_key in result: - if "pages" in result[item_key] and "page" in item: - result[item_key]["pages"].append(str(item["page"])) - elif "page" in item: - result[item_key]["pages"] = [ - result[item_key]["pages"], - str(item["page"]), - ] - else: - if "page" in item: - result[item_key] = { - "source": item_key, - "pages": [str(item["page"])], - } - else: - result[item_key] = {"source": item_key} - return list(result.values()) - @property def chat_type(self) -> str: return ChatScene.ChatKnowledge.value() @@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat): def get_space_context(self, space_name): service = KnowledgeService() return 
service.get_space_context(space_name) + + +def _merge_by_key(data, key): + result = {} + for item in data: + if item.get(key): + item_key = os.path.basename(item.get(key)) + if item_key in result: + if "pages" in result[item_key] and "page" in item: + result[item_key]["pages"].append(str(item["page"])) + elif "page" in item: + result[item_key]["pages"] = [ + result[item_key]["pages"], + str(item["page"]), + ] + else: + if "page" in item: + result[item_key] = { + "source": item_key, + "pages": [str(item["page"])], + } + else: + result[item_key] = {"source": item_key} + return list(result.values()) diff --git a/pilot/scene/operator/__init__.py b/pilot/scene/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/operator/_experimental.py b/pilot/scene/operator/_experimental.py new file mode 100644 index 000000000..f0ee06179 --- /dev/null +++ b/pilot/scene/operator/_experimental.py @@ -0,0 +1,255 @@ +from typing import Dict, Optional, List, Any +from dataclasses import dataclass +import datetime +import os +from pilot.awel import MapOperator +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.scene.message import OnceConversation +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType + + +from pilot.memory.chat_history.base import BaseChatHistoryMemory +from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory + +# TODO move global config +CFG = Config() + + +@dataclass +class ChatContext: + current_user_input: str + model_name: Optional[str] + chat_session_id: Optional[str] = None + select_param: Optional[str] = None + chat_scene: Optional[ChatScene] = ChatScene.ChatNormal + prompt_template: Optional[PromptTemplate] = None + chat_retention_rounds: Optional[int] = 0 + history_storage: Optional[BaseChatHistoryMemory] = None + history_manager: Optional["ChatHistoryManager"] = None + # The input values for 
prompt template + input_values: Optional[Dict] = None + echo: Optional[bool] = False + + def build_model_payload(self) -> Dict: + if not self.input_values: + raise ValueError("The input value can't be empty") + llm_messages = self.history_manager._new_chat(self.input_values) + return { + "model": self.model_name, + "prompt": "", + "messages": llm_messages, + "temperature": float(self.prompt_template.temperature), + "max_new_tokens": int(self.prompt_template.max_new_tokens), + "echo": self.echo, + } + + +class ChatHistoryManager: + def __init__( + self, + chat_ctx: ChatContext, + prompt_template: PromptTemplate, + history_storage: BaseChatHistoryMemory, + chat_retention_rounds: Optional[int] = 0, + ) -> None: + self._chat_ctx = chat_ctx + self.chat_retention_rounds = chat_retention_rounds + self.current_message: OnceConversation = OnceConversation( + chat_ctx.chat_scene.value() + ) + self.prompt_template = prompt_template + self.history_storage: BaseChatHistoryMemory = history_storage + self.history_message: List[OnceConversation] = history_storage.messages() + self.current_message.model_name = chat_ctx.model_name + if chat_ctx.select_param: + if len(chat_ctx.chat_scene.param_types()) > 0: + self.current_message.param_type = chat_ctx.chat_scene.param_types()[0] + self.current_message.param_value = chat_ctx.select_param + + def _new_chat(self, input_values: Dict) -> List[ModelMessage]: + self.current_message.chat_order = len(self.history_message) + 1 + self.current_message.add_user_message(self._chat_ctx.current_user_input) + self.current_message.start_date = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + self.current_message.tokens = 0 + if self.prompt_template.template: + current_prompt = self.prompt_template.format(**input_values) + self.current_message.add_system_message(current_prompt) + return self._generate_llm_messages() + + def _generate_llm_messages(self) -> List[ModelMessage]: + from pilot.scene.base_chat import ( + _load_system_message, + 
class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
    """Attach a chat-history storage backend to the chat context.

    Args:
        history: Optional storage to use for every request handled by this
            operator. When omitted, a storage is obtained from
            ``ChatHistory().get_store_instance`` for the context's
            ``chat_session_id``.
    """

    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Set ``input_value.history_storage`` and pass the context on.

        Args:
            input_value: The chat context flowing through the DAG.

        Returns:
            The same ``ChatContext`` with ``history_storage`` populated.
        """
        if self._history is not None:
            # Hand the pre-configured storage to the context. Returning the
            # storage itself would break the declared ChatContext output and
            # any downstream operator that reads context attributes.
            input_value.history_storage = self._history
            return input_value
        chat_history_fac = ChatHistory()
        input_value.history_storage = chat_history_fac.get_store_instance(
            input_value.chat_session_id
        )
        return input_value
class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
    """Retrieve knowledge-space documents and fill the prompt input values.

    The knowledge space name is taken from ``ChatContext.select_param``. The
    operator writes ``context``, ``question`` and ``relations`` into
    ``ChatContext.input_values`` and may override parts of
    ``ChatContext.prompt_template`` from the space's stored context.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Run a similarity search for the user input and record the result.

        Args:
            input_value: Chat context; ``prompt_template`` must already be
                set because its attributes are assigned here.

        Returns:
            The same ``ChatContext`` with ``input_values`` populated.
        """
        from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )

        # Per-space settings override the global defaults when present.
        space_context = await self._get_space_context(knowledge_space)
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if space_context is None
            else int(space_context["embedding"]["topk"])
        )
        max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if space_context is None or space_context.get("prompt") is None
            else int(space_context["prompt"]["max_token"])
        )
        input_value.prompt_template.template_is_strict = False
        if space_context and space_context.get("prompt"):
            input_value.prompt_template.template_define = space_context["prompt"][
                "scene"
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]

        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            input_value.current_user_input,
            top_k,
        )
        if not docs:
            # Normalise an empty/None result so the code below can iterate
            # safely; nothing may touch `docs` before this guard.
            docs = []
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
        # NOTE(review): this slices by element (documents, or characters of
        # the fallback string), not by tokens -- confirm that is intended.
        context = context[:max_token]
        relations = list(
            {os.path.basename(str(d.metadata.get("source", ""))) for d in docs}
        )
        input_value.input_values = {
            "context": context,
            "question": input_value.current_user_input,
            "relations": relations,
        }
        return input_value

    async def _get_space_context(self, space_name):
        """Fetch the stored context of ``space_name`` via ``KnowledgeService``."""
        from pilot.server.knowledge.service import KnowledgeService

        service = KnowledgeService()
        return await self.blocking_func_to_async(service.get_space_context, space_name)
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    """Describe every field of a pydantic model as a ``ParameterDescription``.

    Works with pydantic v1 and v2: the JSON schema and the field registry are
    read through the API that matches the installed major version.

    Args:
        model_cls: The pydantic model class (not an instance) to inspect.

    Returns:
        One ``ParameterDescription`` per entry of the schema's
        ``properties``, in schema order.
    """
    import pydantic

    version = int(pydantic.VERSION.split(".")[0])
    is_v2 = version >= 2
    schema = model_cls.model_json_schema() if is_v2 else model_cls.schema()
    # `model_fields` only exists in pydantic v2; v1 exposes `__fields__`.
    model_fields = model_cls.model_fields if is_v2 else model_cls.__fields__
    required_fields = set(schema.get("required", []))
    # A plain string; a stray trailing comma would turn it into a 1-tuple.
    param_class = f"{model_cls.__module__}.{model_cls.__name__}"

    param_descs = []
    for field_name, field_schema in schema.get("properties", {}).items():
        model_field = model_fields[field_name]

        param_type = field_schema.get("type")
        if not param_type:
            # Optional/union fields are emitted as `anyOf`; take the first
            # concrete type. Entries without "type" (e.g. `$ref`) are skipped.
            for any_of in field_schema.get("anyOf", []):
                candidate = any_of.get("type")
                if candidate and candidate != "null":
                    param_type = candidate
                    break

        if is_v2:
            default_value = (
                model_field.default
                if hasattr(model_field, "default")
                and str(model_field.default) != "PydanticUndefined"
                else None
            )
        else:
            # NOTE(review): for fields that allow None the declared default
            # is ignored in favour of default_factory -- confirm intent.
            default_value = (
                model_field.default
                if not model_field.allow_none
                else (
                    model_field.default_factory()
                    if callable(model_field.default_factory)
                    else None
                )
            )

        valid_values = None
        ext_metadata = None
        if hasattr(model_field, "field_info"):
            field_info = model_field.field_info
            valid_values = (
                list(field_info.choices) if hasattr(field_info, "choices") else None
            )
            ext_metadata = field_info.extra if hasattr(field_info, "extra") else None

        param_descs.append(
            ParameterDescription(
                param_class=param_class,
                param_name=field_name,
                param_type=param_type,
                default_value=default_value,
                description=field_schema.get("description", ""),
                required=field_name in required_fields,
                valid_values=valid_values,
                ext_metadata=ext_metadata,
            )
        )
    return param_descs