feat(awel): AWEL supports http trigger and add some AWEL examples (#815)

- AWEL supports http trigger.
- Disassemble the KBQA into some atomic operators.
- Add some AWEL examples.
This commit is contained in:
Aries-ckt 2023-11-21 15:31:28 +08:00 committed by GitHub
commit 2ce77519de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1199 additions and 187 deletions

View File

@ -0,0 +1,54 @@
"""AWEL: Simple chat dag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"user_input": "hello"
}'
"""
from typing import Dict
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
from pilot.scene.base_message import ModelMessage
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator
class TriggerReqBody(BaseModel):
    """Request body schema for the ``/examples/simple_chat`` HTTP trigger."""

    model: str = Field(..., description="Model name")
    user_input: str = Field(..., description="User input")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Transform the HTTP trigger request body into model-inference parameters."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Build the parameter dict consumed by the downstream ModelOperator.

        Args:
            input_value: Parsed request body from the HTTP trigger.

        Returns:
            Dict with ``prompt``, ``messages``, ``model`` and ``echo`` keys.
        """
        messages = [ModelMessage.build_human_message(input_value.user_input)]
        # Downstream expects plain dicts, not ModelMessage objects.
        hist = [m.dict() for m in messages]
        params = {
            "prompt": input_value.user_input,
            "messages": hist,
            "model": input_value.model,
            "echo": False,
        }
        print(f"Receive input value: {input_value}")
        return params
with DAG("dbgpt_awel_simple_dag_example") as dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_chat", methods="POST", request_body=TriggerReqBody
    )
    # Convert the request body into model-inference parameters.
    request_handle_task = RequestHandleOperator()
    # Call the model worker with those parameters.
    model_task = ModelOperator()
    # type(out) == ModelOutput
    model_parse_task = MapOperator(lambda out: out.to_dict())
    trigger >> request_handle_task >> model_task >> model_parse_task

View File

@ -0,0 +1,32 @@
"""AWEL: Simple dag example
Example:
.. code-block:: shell
curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan
"""
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
class TriggerReqBody(BaseModel):
    """Request body schema for the ``/examples/hello`` HTTP trigger."""

    name: str = Field(..., description="User name")
    age: int = Field(18, description="User age")
class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
    """Turn the parsed request body into a greeting string."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> str:
        """Log the incoming request and return the greeting text."""
        print(f"Receive input value: {input_value}")
        greeting = f"Hello, {input_value.name}, your age is {input_value.age}"
        return greeting
with DAG("simple_dag_example") as dag:
    # GET endpoint; query params are parsed into TriggerReqBody.
    trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
    map_node = RequestHandleOperator()
    trigger >> map_node

View File

@ -0,0 +1,70 @@
"""AWEL: Simple rag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
-H "Content-Type: application/json" -d '{
"conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
"model_name": "proxyllm",
"chat_mode": "chat_knowledge",
"user_input": "What is DB-GPT?",
"select_param": "default"
}'
"""
from pilot.awel import HttpTrigger, DAG, MapOperator
from pilot.scene.operator._experimental import (
ChatContext,
PromptManagerOperator,
ChatHistoryStorageOperator,
ChatHistoryOperator,
EmbeddingEngingOperator,
BaseChatOperator,
)
from pilot.scene.base import ChatScene
from pilot.openapi.api_view_model import ConversationVo
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator
class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
    """Map the incoming ConversationVo request into a ChatContext."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ConversationVo) -> ChatContext:
        """Build a chat context from the HTTP request body.

        Note: this example is hard-wired to the knowledge-chat scene.
        """
        return ChatContext(
            current_user_input=input_value.user_input,
            model_name=input_value.model_name,
            chat_session_id=input_value.conv_uid,
            select_param=input_value.select_param,
            chat_scene=ChatScene.ChatKnowledge,
        )
with DAG("simple_rag_example") as dag:
    # Receive the http request and trigger the dag to run.
    trigger_task = HttpTrigger(
        "/examples/simple_rag", methods="POST", request_body=ConversationVo
    )
    req_parse_task = RequestParseOperator()
    prompt_task = PromptManagerOperator()
    # Chat-history operators from the experimental KBQA decomposition.
    history_storage_task = ChatHistoryStorageOperator()
    history_task = ChatHistoryOperator()
    embedding_task = EmbeddingEngingOperator()
    chat_task = BaseChatOperator()
    model_task = ModelOperator()
    # Extract the generated text from the ModelOutput dict.
    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])
    (
        trigger_task
        >> req_parse_task
        >> prompt_task
        >> history_storage_task
        >> history_task
        >> embedding_task
        >> chat_task
        >> model_task
        >> output_parser_task
    )

View File

@ -1,8 +1,17 @@
"""Agentic Workflow Expression Language (AWEL)""" """Agentic Workflow Expression Language (AWEL)
Note:
AWEL is still an experimental feature and only opens the lowest level API.
The stability of this API cannot be guaranteed at present.
"""
from pilot.component import SystemApp
from .dag.base import DAGContext, DAG from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import ( from .operator.common_operator import (
JoinOperator, JoinOperator,
ReduceStreamOperator, ReduceStreamOperator,
@ -28,6 +37,7 @@ from .task.task_impl import (
SimpleStreamTaskOutput, SimpleStreamTaskOutput,
_is_async_iterator, _is_async_iterator,
) )
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner from .runner.local_runner import DefaultWorkflowRunner
__all__ = [ __all__ = [
@ -57,4 +67,21 @@ __all__ = [
"StreamifyAbsOperator", "StreamifyAbsOperator",
"UnstreamifyAbsOperator", "UnstreamifyAbsOperator",
"TransformStreamAbsOperator", "TransformStreamAbsOperator",
"HttpTrigger",
] ]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
    """Initialize the AWEL subsystem.

    Registers the trigger manager and DAG manager components, installs the
    default workflow runner, then loads all DAG definitions.

    Args:
        system_app: The application-wide component registry.
        dag_filepath: File or directory containing DAG definition scripts.
    """
    # Local imports avoid circular dependencies at module import time.
    from .dag.dag_manager import DAGManager
    from .dag.base import DAGVar
    from .trigger.trigger_manager import DefaultTriggerManager
    from .operator.base import initialize_runner

    DAGVar.set_current_system_app(system_app)
    system_app.register(DefaultTriggerManager)
    dag_manager = DAGManager(system_app, dag_filepath)
    system_app.register_instance(dag_manager)
    initialize_runner(DefaultWorkflowRunner())
    # Load all dags
    dag_manager.load_dags()

7
pilot/awel/base.py Normal file
View File

@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class Trigger(ABC):
    """Abstract base class for anything that can start a workflow run."""

    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""

View File

@ -1,14 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid import uuid
import contextvars import contextvars
import threading import threading
import asyncio import asyncio
import logging
from collections import deque from collections import deque
from functools import cache
from concurrent.futures import Executor
from pilot.component import SystemApp
from ..resource.base import ResourceGroup from ..resource.base import ResourceGroup
from ..task.base import TaskContext from ..task.base import TaskContext
logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]] DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
@ -96,6 +102,8 @@ class DependencyMixin(ABC):
class DAGVar: class DAGVar:
_thread_local = threading.local() _thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque()) _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
@classmethod @classmethod
def enter_dag(cls, dag) -> None: def enter_dag(cls, dag) -> None:
@ -138,18 +146,48 @@ class DAGVar:
return cls._thread_local.current_dag_stack[-1] return cls._thread_local.current_dag_stack[-1]
return None return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
if not cls._system_app:
raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
cls._executor = executor
class DAGNode(DependencyMixin, ABC): class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode""" """The resource group of current DAGNode"""
def __init__( def __init__(
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None self,
dag: Optional["DAG"] = None,
node_id: Optional[str] = None,
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self._upstream: List["DAGNode"] = [] self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = [] self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag() self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag: if not node_id and self._dag:
node_id = self._dag._new_node_id() node_id = self._dag._new_node_id()
self._node_id: str = node_id self._node_id: str = node_id
@ -159,6 +197,10 @@ class DAGNode(DependencyMixin, ABC):
def node_id(self) -> str: def node_id(self) -> str:
return self._node_id return self._node_id
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_node_id(self, node_id: str) -> None: def set_node_id(self, node_id: str) -> None:
self._node_id = node_id self._node_id = node_id
@ -178,7 +220,7 @@ class DAGNode(DependencyMixin, ABC):
return self._node_name return self._node_name
@property @property
def dag(self) -> "DAGNode": def dag(self) -> "DAG":
return self._dag return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode": def set_upstream(self, nodes: DependencyType) -> "DAGNode":
@ -254,17 +296,69 @@ class DAG:
def __init__( def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None: ) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {} self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None
def _append_node(self, node: DAGNode) -> None: def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
def _new_node_id(self) -> str: def _new_node_id(self) -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())
@property
def dag_id(self) -> str:
return self._dag_id
def _build(self) -> None:
    """Walk the node map and cache root, leaf and trigger node lists."""
    from ..operator.common_operator import TriggerOperator

    nodes = set()
    for _, node in self.node_map.items():
        nodes = nodes.union(_get_nodes(node))
    # NOTE(review): the cached fields are declared as Set[...] but lists
    # are stored here and exposed by the properties — confirm intent.
    self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
    self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
    self._trigger_nodes = list(
        set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
    )
@property
def root_nodes(self) -> List[DAGNode]:
if not self._root_nodes:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
if not self._leaf_nodes:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self):
if not self._trigger_nodes:
self._build()
return self._trigger_nodes
def __enter__(self): def __enter__(self):
DAGVar.enter_dag(self) DAGVar.enter_dag(self)
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag() DAGVar.exit_dag()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes

View File

@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
    """Component that loads AWEL DAGs and registers their triggers."""

    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        # dag_id -> DAG, used to detect duplicate DAG IDs.
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        """Load DAGs from the configured path and register their triggers.

        Raises:
            ValueError: If two loaded DAGs share the same DAG ID.
        """
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
            # Fix: record the DAG; previously nothing was ever inserted,
            # so the duplicate-ID check above could never fire.
            self.dag_map[dag_id] = dag
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            # logger.warn is deprecated in favor of logger.warning.
            logger.warning("No trigger manager, not register dag trigger")

93
pilot/awel/dag/loader.py Normal file
View File

@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
    """Abstract interface for loading DAG definitions from some source."""

    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags"""
class LocalFileDAGLoader(DAGLoader):
    """Load DAG definitions from a local python file or directory."""

    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        """Return every DAG found under the configured path (may be empty)."""
        path = self._filepath
        if not os.path.exists(path):
            return []
        if os.path.isdir(path):
            return _process_directory(path)
        return _process_file(path)
def _process_directory(directory: str) -> List[DAG]:
    """Collect DAGs from every ``.py`` file directly inside *directory*."""
    collected: List[DAG] = []
    for entry in os.listdir(directory):
        if not entry.endswith(".py"):
            continue
        collected.extend(_process_file(os.path.join(directory, entry)))
    return collected
def _process_file(filepath) -> List[DAG]:
    """Import the file at *filepath* and return the DAG objects it defines."""
    return _process_modules(_load_modules_from_file(filepath))
def _load_modules_from_file(filepath: str):
    """Import the python file at *filepath* as a uniquely named module.

    A SHA-1 hash of the path is baked into the module name so files with the
    same basename in different directories do not collide in ``sys.modules``.

    Returns:
        A single-element list with the imported module, or an empty list if
        importing the file raised any exception.
    """
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
    if mod_name in sys.modules:
        del sys.modules[mod_name]
    # The original wrapped this in a nested `parse` function that was defined
    # and called exactly once; inlined for clarity.
    try:
        loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
        spec = importlib.util.spec_from_loader(mod_name, loader)
        new_module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = new_module
        loader.exec_module(new_module)
        return [new_module]
    except Exception:
        msg = traceback.format_exc()
        logger.error(f"Failed to import: {filepath}, error message: {msg}")
        # TODO save error message
        return []
def _process_modules(mods) -> List[DAG]:
    """Extract all module-level DAG objects from the given modules.

    Args:
        mods: Iterable of imported module objects.

    Returns:
        List of DAG instances found at module top level.
    """
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            # Fixed previously garbled log message ("Failed to dag file").
            logger.error(f"Failed to process dag file, error message: {msg}")
    return found_dags

View File

@ -14,6 +14,13 @@ from typing import (
) )
import functools import functools
from inspect import signature from inspect import signature
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import ( from ..task.base import (
@ -67,6 +74,19 @@ class BaseOperatorMeta(ABCMeta):
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag() dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id") task_id: Optional[str] = kwargs.get("task_id")
system_app: Optional[SystemApp] = (
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
executor = kwargs.get("executor") or DAGVar.get_executor()
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
if not task_id and dag: if not task_id and dag:
task_id = dag._new_node_id() task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
@ -80,6 +100,10 @@ class BaseOperatorMeta(ABCMeta):
kwargs["task_id"] = task_id kwargs["task_id"] = task_id
if not kwargs.get("runner"): if not kwargs.get("runner"):
kwargs["runner"] = runner kwargs["runner"] = runner
if not kwargs.get("system_app"):
kwargs["system_app"] = system_app
if not kwargs.get("executor"):
kwargs["executor"] = executor
real_obj = func(self, *args, **kwargs) real_obj = func(self, *args, **kwargs)
return real_obj return real_obj
@ -171,7 +195,12 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
out_ctx = await self._runner.execute_workflow(self, call_data) out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output_stream return out_ctx.current_task_context.task_output.output_stream
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
return await blocking_func_to_async(self._executor, func, *args, **kwargs)
def initialize_awel(runner: WorkflowRunner):
def initialize_runner(runner: WorkflowRunner):
global default_runner global default_runner
default_runner = runner default_runner = runner

View File

@ -237,3 +237,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
task_output = await self._input_source.read(curr_task_ctx) task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output) curr_task_ctx.set_task_output(task_output)
return task_output return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
    """Base operator for DAG entry points fired by external events.

    Uses a SimpleCallDataInputSource so the payload arrives via ``call_data``
    when the trigger fires, rather than from an upstream operator.
    """

    def __init__(self, **kwargs) -> None:
        # Local import to avoid a circular dependency with the task module.
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)

View File

@ -3,7 +3,7 @@ import logging
from ..dag.base import DAGContext from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager from .job_manager import JobManager
@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs[node.node_id] = task_ctx node_outputs[node.node_id] = task_ctx
return return
try: try:
logger.info( logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}" f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
) )
await node._run(dag_ctx) await node._run(dag_ctx)
@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
if isinstance(node, BranchOperator): if isinstance(node, BranchOperator):
skip_nodes = task_ctx.metadata.get("skip_node_names", []) skip_nodes = task_ctx.metadata.get("skip_node_names", [])
logger.info( logger.debug(
f"Current is branch operator, skip node names: {skip_nodes}" f"Current is branch operator, skip node names: {skip_nodes}"
) )
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids) _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput
class Trigger(TriggerOperator, ABC):
    """Abstract base for concrete triggers (e.g. HTTP) that start a DAG run."""

    @abstractmethod
    async def trigger(self, end_operator: "BaseOperator") -> None:
        """Trigger the workflow or a specific operation in the workflow."""

View File

@ -0,0 +1,117 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from pydantic import BaseModel
import logging
from .base import Trigger
from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
    """Trigger that exposes a DAG as an HTTP endpoint on a FastAPI router.

    When the route is hit, the request body is parsed (per ``request_body``)
    and passed as ``call_data`` to the DAG's single leaf node.
    """

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        """Create an HTTP trigger.

        Args:
            endpoint: Route path; a leading "/" is added if missing.
            methods: HTTP method or list of methods (default "GET").
            request_body: How to parse the request (pydantic model, Request,
                or str — see RequestBody).
            streaming_response: If True, stream the leaf node's output.
            response_model: FastAPI response model for the route.
            response_headers: Headers for the streaming response.
            response_media_type: Media type for the streaming response.
            status_code: HTTP status code on success.
            router_tags: OpenAPI tags for the route.
        """
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        self._endpoint = endpoint
        self._methods = methods
        self._req_body = request_body
        self._streaming_response = streaming_response
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        self._end_node: BaseOperator = None

    async def trigger(self) -> None:
        # Triggering happens via the mounted route, not this method.
        pass

    def mount_to_router(self, router: "APIRouter") -> None:
        """Register a dynamically built route function on *router*."""
        from fastapi import Depends
        from fastapi.responses import StreamingResponse

        methods = self._methods if isinstance(self._methods, list) else [self._methods]

        def create_route_function(name):
            async def _request_body_dependency(request: Request):
                return await _parse_request_body(request, self._req_body)

            async def route_function(body: Any = Depends(_request_body_dependency)):
                # The DAG must have exactly one leaf; it is called with the
                # parsed body as call_data.
                end_node = self.dag.leaf_nodes
                if len(end_node) != 1:
                    raise ValueError("HttpTrigger just support one leaf node in dag")
                end_node = end_node[0]
                if not self._streaming_response:
                    return await end_node.call(call_data={"data": body})
                else:
                    headers = self._response_headers
                    media_type = (
                        self._response_media_type
                        if self._response_media_type
                        else "text/event-stream"
                    )
                    if not headers:
                        # Default SSE-style headers for streaming output.
                        headers = {
                            "Content-Type": "text/event-stream",
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Transfer-Encoding": "chunked",
                        }
                    return StreamingResponse(
                        end_node.call_stream(call_data={"data": body}),
                        headers=headers,
                        media_type=media_type,
                    )

            # Give the route a unique name so FastAPI/OpenAPI can tell them apart.
            route_function.__name__ = name
            return route_function

        function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
        dynamic_route_function = create_route_function(function_name)
        logger.info(
            f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
        )
        router.api_route(
            self._endpoint,
            methods=methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
        )(dynamic_route_function)
async def _parse_request_body(
    request: Request, request_body_cls: Optional[Type[BaseModel]]
):
    """Materialize the request payload as *request_body_cls*, if one is given.

    POST bodies are parsed as JSON; GET uses query parameters; any other
    method returns the raw request object unchanged.
    """
    if not request_body_cls:
        return None
    method = request.method
    if method == "POST":
        payload = await request.json()
        return request_body_cls(**payload)
    if method == "GET":
        return request_body_cls(**request.query_params)
    return request

View File

@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
if TYPE_CHECKING:
from fastapi import APIRouter
from pilot.component import SystemApp, BaseComponent, ComponentType
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
    """Abstract manager that accepts trigger registrations."""

    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager"""
class HttpTriggerManager(TriggerManager):
    """Register HttpTrigger nodes as routes on a FastAPI APIRouter."""

    def __init__(
        self,
        router: Optional["APIRouter"] = None,
        router_prefix: Optional[str] = "/api/v1/awel/trigger",
    ) -> None:
        if not router:
            from fastapi import APIRouter

            router = APIRouter()
        self._router_prefix = router_prefix
        self._router = router
        # trigger node_id -> HttpTrigger; guards against double registration.
        self._trigger_map = {}

    def register_trigger(self, trigger: Any) -> None:
        """Mount *trigger* (must be an HttpTrigger) onto the router.

        Raises:
            ValueError: If *trigger* is not an HttpTrigger.
        """
        from .http_trigger import HttpTrigger

        if not isinstance(trigger, HttpTrigger):
            raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
        # Removed the no-op annotated self-assignment from the original.
        trigger_id = trigger.node_id
        if trigger_id not in self._trigger_map:
            trigger.mount_to_router(self._router)
            self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
        """Attach the accumulated router to the FastAPI application."""
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
        system_app.app.include_router(
            self._router, prefix=self._router_prefix, tags=["AWEL"]
        )
class DefaultTriggerManager(TriggerManager, BaseComponent):
    """System component that dispatches triggers to concrete sub-managers."""

    name = ComponentType.AWEL_TRIGGER_MANAGER

    def __init__(self, system_app: Optional[SystemApp] = None):
        # Was `SystemApp | None`, which is evaluated at runtime and requires
        # Python 3.10+; Optional[...] keeps the module importable on 3.8/3.9.
        self.system_app = system_app
        self.http_trigger = HttpTriggerManager()
        super().__init__(None)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def register_trigger(self, trigger: Any) -> None:
        """Route *trigger* to the matching sub-manager.

        Raises:
            ValueError: If no sub-manager supports the trigger type.
        """
        from .http_trigger import HttpTrigger

        if isinstance(trigger, HttpTrigger):
            logger.info(f"Register trigger {trigger}")
            self.http_trigger.register_trigger(trigger)
        else:
            # Fixed error-message typo ("Unsupport").
            raise ValueError(f"Unsupported trigger: {trigger}")

    def after_register(self) -> None:
        """Attach all collected HTTP routes to the application."""
        self.http_trigger._init_app(self.system_app)

View File

@ -54,6 +54,8 @@ class ComponentType(str, Enum):
TRACER = "dbgpt_tracer" TRACER = "dbgpt_tracer"
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage" TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default" RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
class BaseComponent(LifeCycle, ABC): class BaseComponent(LifeCycle, ABC):

View File

@ -16,6 +16,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts") FONT_DIR = os.path.join(PILOT_PATH, "fonts")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
current_directory = os.getcwd() current_directory = os.getcwd()

View File

@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
logger.error("init create conversation log error" + str(e)) logger.error("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None: def append(self, once_message: OnceConversation) -> None:
logger.info(f"db history append: {once_message}") logger.debug(f"db history append: {once_message}")
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
self.chat_seesion_id self.chat_seesion_id
) )

View File

@ -7,9 +7,10 @@ from pilot.awel import (
MapOperator, MapOperator,
TransformStreamAbsOperator, TransformStreamAbsOperator,
) )
from pilot.component import ComponentType
from pilot.awel.operator.base import BaseOperator from pilot.awel.operator.base import BaseOperator
from pilot.model.base import ModelOutput from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager from pilot.model.cluster import WorkerManager, WorkerManagerFactory
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +30,7 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
streamify: Asynchronously processes a stream of inputs, yielding model outputs. streamify: Asynchronously processes a stream of inputs, yielding model outputs.
""" """
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.worker_manager = worker_manager self.worker_manager = worker_manager
@ -42,6 +43,10 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
Returns: Returns:
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
""" """
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for out in self.worker_manager.generate_stream(input_value): async for out in self.worker_manager.generate_stream(input_value):
yield out yield out
@ -57,9 +62,9 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
map: Asynchronously processes a single input and returns the model output. map: Asynchronously processes a single input and returns the model output.
""" """
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
self.worker_manager = worker_manager
super().__init__(**kwargs) super().__init__(**kwargs)
self.worker_manager = worker_manager
async def map(self, input_value: Dict) -> ModelOutput: async def map(self, input_value: Dict) -> ModelOutput:
"""Process a single input and return the model output. """Process a single input and return the model output.
@ -70,6 +75,10 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
Returns: Returns:
ModelOutput: The output from the model. ModelOutput: The output from the model.
""" """
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return await self.worker_manager.generate(input_value) return await self.worker_manager.generate(input_value)

View File

@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend payloads["model"] = proxyllm_backend
logger.info( logger.info(f"Send request to real model {proxyllm_backend}")
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
)
return history, payloads return history, payloads

View File

@ -68,7 +68,7 @@ class BaseChat(ABC):
CFG.prompt_template_registry.get_prompt_template( CFG.prompt_template_registry.get_prompt_template(
self.chat_mode.value(), self.chat_mode.value(),
language=CFG.LANGUAGE, language=CFG.LANGUAGE,
model_name=CFG.LLM_MODEL, model_name=self.llm_model,
proxyllm_backend=CFG.PROXYLLM_BACKEND, proxyllm_backend=CFG.PROXYLLM_BACKEND,
) )
) )
@ -141,13 +141,7 @@ class BaseChat(ABC):
return speak_to_user return speak_to_user
async def __call_base(self): async def __call_base(self):
import inspect input_values = await self.generate_input_values()
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())
else self.generate_input_values()
)
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(self.current_user_input)
@ -379,16 +373,18 @@ class BaseChat(ABC):
if self.prompt_template.template_define: if self.prompt_template.template_define:
text += self.prompt_template.template_define + self.prompt_template.sep text += self.prompt_template.template_define + self.prompt_template.sep
### Load prompt ### Load prompt
text += self.__load_system_message() text += _load_system_message(self.current_message, self.prompt_template)
### Load examples ### Load examples
text += self.__load_example_messages() text += _load_example_messages(self.prompt_template)
### Load History ### Load History
text += self.__load_history_messages() text += _load_history_messages(
self.prompt_template, self.history_message, self.chat_retention_rounds
)
### Load User Input ### Load User Input
text += self.__load_user_message() text += _load_user_message(self.current_message, self.prompt_template)
return text return text
def generate_llm_messages(self) -> List[ModelMessage]: def generate_llm_messages(self) -> List[ModelMessage]:
@ -406,137 +402,26 @@ class BaseChat(ABC):
) )
) )
### Load prompt ### Load prompt
messages += self.__load_system_message(str_message=False) messages += _load_system_message(
self.current_message, self.prompt_template, str_message=False
)
### Load examples ### Load examples
messages += self.__load_example_messages(str_message=False) messages += _load_example_messages(self.prompt_template, str_message=False)
### Load History ### Load History
messages += self.__load_history_messages(str_message=False) messages += _load_history_messages(
self.prompt_template,
self.history_message,
self.chat_retention_rounds,
str_message=False,
)
### Load User Input ### Load User Input
messages += self.__load_user_message(str_message=False) messages += _load_user_message(
self.current_message, self.prompt_template, str_message=False
)
return messages return messages
def __load_system_message(self, str_message: bool = True):
system_convs = self.current_message.get_system_conv()
system_text = ""
system_messages = []
for system_conv in system_convs:
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content)
)
return system_text if str_message else system_messages
def __load_user_message(self, str_message: bool = True):
user_conv = self.current_message.get_user_conv()
user_messages = []
if user_conv:
user_text = (
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
)
user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content)
)
return user_text if str_message else user_messages
else:
raise ValueError("Hi! What do you want to talk about")
def __load_example_messages(self, str_message: bool = True):
example_text = ""
example_messages = []
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
example_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
example_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return example_text if str_message else example_messages
def __load_history_messages(self, str_message: bool = True):
history_text = ""
history_messages = []
if self.prompt_template.need_historical_messages:
if self.history_message:
logger.info(
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
)
if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
if self.chat_retention_rounds > 1:
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(
role=message_type, content=message_content
)
)
else:
### user all history
for conversation in self.history_message:
for message in conversation["messages"]:
### histroy message not have promot and view info
if not message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = message["type"]
message_content = message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return history_text if str_message else history_messages
def current_ai_response(self) -> str: def current_ai_response(self) -> str:
for message in self.current_message.messages: for message in self.current_message.messages:
if message.type == "view": if message.type == "view":
@ -656,3 +541,127 @@ def _build_model_operator(
cache_check_branch_node >> cached_node >> join_node cache_check_branch_node >> cached_node >> join_node
return join_node return join_node
def _load_system_message(
    current_message: OnceConversation,
    prompt_template: PromptTemplate,
    str_message: bool = True,
):
    """Load the system conversations of *current_message*.

    Returns:
        The concatenated ``"type:content<sep>"`` text when *str_message*
        is True, otherwise a list of ``ModelMessage`` objects.
    """
    convs = current_message.get_system_conv()
    if str_message:
        return "".join(
            f"{conv.type}:{conv.content}{prompt_template.sep}" for conv in convs
        )
    return [ModelMessage(role=conv.type, content=conv.content) for conv in convs]
def _load_user_message(
    current_message: OnceConversation,
    prompt_template: PromptTemplate,
    str_message: bool = True,
):
    """Load the current user conversation from *current_message*.

    Returns:
        Formatted ``"type:content<sep>"`` text when *str_message* is True,
        otherwise a single-element list of ``ModelMessage``.

    Raises:
        ValueError: When the current message has no user conversation.
    """
    user_conv = current_message.get_user_conv()
    if not user_conv:
        raise ValueError("Hi! What do you want to talk about")
    if str_message:
        return user_conv.type + ":" + user_conv.content + prompt_template.sep
    return [ModelMessage(role=user_conv.type, content=user_conv.content)]
def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
    """Load few-shot example messages attached to *prompt_template*.

    VIEW and SYSTEM entries are skipped. Returns concatenated text when
    *str_message* is True, otherwise a list of ``ModelMessage`` objects.
    """
    skipped_roles = (ModelMessageRoleType.VIEW, ModelMessageRoleType.SYSTEM)
    text_parts = []
    messages = []
    selector = prompt_template.example_selector
    if selector:
        for conv in selector.examples():
            for msg in conv["messages"]:
                msg_type = msg["type"]
                if msg_type in skipped_roles:
                    continue
                content = msg["data"]["content"]
                text_parts.append(msg_type + ":" + content + prompt_template.sep)
                messages.append(ModelMessage(role=msg_type, content=content))
    return "".join(text_parts) if str_message else messages
def _load_history_messages(
    prompt_template: PromptTemplate,
    history_message: List[OnceConversation],
    chat_retention_rounds: int,
    str_message: bool = True,
):
    """Load retained conversation history.

    When more stored rounds exist than ``chat_retention_rounds``, only the
    first round plus the most recent ``chat_retention_rounds - 1`` rounds
    are kept; otherwise every round is used. VIEW and SYSTEM messages are
    always filtered out.

    Args:
        prompt_template: Supplies ``sep`` (text separator) and the
            ``need_historical_messages`` flag that gates loading at all.
        history_message: Prior rounds; each round is a dict-like object
            with a "messages" list.
        chat_retention_rounds: Maximum number of history rounds to retain.
        str_message: Return concatenated "type:content<sep>" text when
            True, otherwise a list of ``ModelMessage`` objects.
    """
    history_text = ""
    history_messages = []
    if prompt_template.need_historical_messages:
        if history_message:
            logger.info(
                f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
            )
        if len(history_message) > chat_retention_rounds:
            # Too many rounds: the very first round is always kept.
            # NOTE(review): retaining round 0 unconditionally looks
            # intentional (it typically carries the task setup) -- confirm.
            for first_message in history_message[0]["messages"]:
                if not first_message["type"] in [
                    ModelMessageRoleType.VIEW,
                    ModelMessageRoleType.SYSTEM,
                ]:
                    message_type = first_message["type"]
                    message_content = first_message["data"]["content"]
                    history_text += (
                        message_type + ":" + message_content + prompt_template.sep
                    )
                    history_messages.append(
                        ModelMessage(role=message_type, content=message_content)
                    )
            if chat_retention_rounds > 1:
                # ...plus the most recent (chat_retention_rounds - 1) rounds.
                index = chat_retention_rounds - 1
                for round_conv in history_message[-index:]:
                    for round_message in round_conv["messages"]:
                        if not round_message["type"] in [
                            ModelMessageRoleType.VIEW,
                            ModelMessageRoleType.SYSTEM,
                        ]:
                            message_type = round_message["type"]
                            message_content = round_message["data"]["content"]
                            history_text += (
                                message_type
                                + ":"
                                + message_content
                                + prompt_template.sep
                            )
                            history_messages.append(
                                ModelMessage(role=message_type, content=message_content)
                            )
        else:
            # Few enough rounds: use all history.
            for conversation in history_message:
                for message in conversation["messages"]:
                    # History messages do not carry prompt and view info.
                    if not message["type"] in [
                        ModelMessageRoleType.VIEW,
                        ModelMessageRoleType.SYSTEM,
                    ]:
                        message_type = message["type"]
                        message_content = message["data"]["content"]
                        history_text += (
                            message_type + ":" + message_content + prompt_template.sep
                        )
                        history_messages.append(
                            ModelMessage(role=message_type, content=message_content)
                        )
    return history_text if str_message else history_messages

View File

@ -117,6 +117,10 @@ class ModelMessage(BaseModel):
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages)) return list(map(lambda m: m.dict(), messages))
    @staticmethod
    def build_human_message(content: str) -> "ModelMessage":
        """Build a single human-role ``ModelMessage`` from raw text."""
        return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
class Generation(BaseModel): class Generation(BaseModel):
"""Output of a single generation.""" """Output of a single generation."""

View File

@ -6,7 +6,6 @@ import re
import sqlparse import sqlparse
import pandas as pd import pandas as pd
import chardet import chardet
import pandas as pd
import numpy as np import numpy as np
from pyparsing import ( from pyparsing import (
CaselessKeyword, CaselessKeyword,
@ -27,6 +26,8 @@ from pyparsing import (
from pilot.common.pd_utils import csv_colunm_foramt from pilot.common.pd_utils import csv_colunm_foramt
from pilot.common.string_utils import is_chinese_include_number from pilot.common.string_utils import is_chinese_include_number
logger = logging.getLogger(__name__)
def excel_colunm_format(old_name: str) -> str: def excel_colunm_format(old_name: str) -> str:
new_column = old_name.strip() new_column = old_name.strip()
@ -263,7 +264,7 @@ class ExcelReader:
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
self.file_name_without_extension = os.path.splitext(file_name)[0] self.file_name_without_extension = os.path.splitext(file_name)[0]
encoding, confidence = detect_encoding(file_path) encoding, confidence = detect_encoding(file_path)
logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})") logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
self.excel_file_name = file_name self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1] self.extension = os.path.splitext(file_name)[1]
# read excel file # read excel file
@ -323,7 +324,7 @@ class ExcelReader:
colunms.append(descrip[0]) colunms.append(descrip[0])
return colunms, results.fetchall() return colunms, results.fetchall()
except Exception as e: except Exception as e:
logging.error("excel sql run error!", e) logger.error(f"excel sql run error!, {str(e)}")
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}") raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
def get_df_by_sql_ex(self, sql): def get_df_by_sql_ex(self, sql):

View File

@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text): def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text) clean_str = super().parse_prompt_response(model_out_text)
logging.info("clean prompt response:", clean_str) logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model # Compatible with community pure sql output model
if self.is_sql_statement(clean_str): if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "") return SqlAction(clean_str, "")
@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser):
thoughts = response[key] thoughts = response[key]
return SqlAction(sql, thoughts) return SqlAction(sql, thoughts)
except Exception as e: except Exception as e:
logging.error("json load faild") logger.error("json load faild")
return SqlAction("", clean_str) return SqlAction("", clean_str)
def parse_view_response(self, speak, data, prompt_response) -> str: def parse_view_response(self, speak, data, prompt_response) -> str:

View File

@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
self.user_input = chat_param["current_user_input"] self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"] self.extract_mode = chat_param["select_param"]
def generate_input_values(self): async def generate_input_values(self):
input_values = { input_values = {
"text": self.user_input, "text": self.user_input,
} }

View File

@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
self.user_input = chat_param["current_user_input"] self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"] self.extract_mode = chat_param["select_param"]
def generate_input_values(self): async def generate_input_values(self):
input_values = { input_values = {
"text": self.user_input, "text": self.user_input,
} }

View File

@ -23,7 +23,7 @@ class ExtractRefineSummary(BaseChat):
self.existing_answer = chat_param["select_param"] self.existing_answer = chat_param["select_param"]
def generate_input_values(self): async def generate_input_values(self):
input_values = { input_values = {
# "context": self.user_input, # "context": self.user_input,
"existing_answer": self.existing_answer, "existing_answer": self.existing_answer,

View File

@ -23,7 +23,7 @@ class ExtractSummary(BaseChat):
self.user_input = chat_param["select_param"] self.user_input = chat_param["select_param"]
def generate_input_values(self): async def generate_input_values(self):
input_values = { input_values = {
"context": self.user_input, "context": self.user_input,
} }

View File

@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat):
self.current_user_input, self.current_user_input,
self.top_k, self.top_k,
) )
self.sources = self.merge_by_key( self.sources = _merge_by_key(
list(map(lambda doc: doc.metadata, docs)), "source" list(map(lambda doc: doc.metadata, docs)), "source"
) )
@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat):
) )
return html return html
def merge_by_key(self, data, key):
result = {}
for item in data:
if item.get(key):
item_key = os.path.basename(item.get(key))
if item_key in result:
if "pages" in result[item_key] and "page" in item:
result[item_key]["pages"].append(str(item["page"]))
elif "page" in item:
result[item_key]["pages"] = [
result[item_key]["pages"],
str(item["page"]),
]
else:
if "page" in item:
result[item_key] = {
"source": item_key,
"pages": [str(item["page"])],
}
else:
result[item_key] = {"source": item_key}
return list(result.values())
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value() return ChatScene.ChatKnowledge.value()
@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat):
def get_space_context(self, space_name): def get_space_context(self, space_name):
service = KnowledgeService() service = KnowledgeService()
return service.get_space_context(space_name) return service.get_space_context(space_name)
def _merge_by_key(data, key):
result = {}
for item in data:
if item.get(key):
item_key = os.path.basename(item.get(key))
if item_key in result:
if "pages" in result[item_key] and "page" in item:
result[item_key]["pages"].append(str(item["page"]))
elif "page" in item:
result[item_key]["pages"] = [
result[item_key]["pages"],
str(item["page"]),
]
else:
if "page" in item:
result[item_key] = {
"source": item_key,
"pages": [str(item["page"])],
}
else:
result[item_key] = {"source": item_key}
return list(result.values())

View File

View File

@ -0,0 +1,255 @@
from typing import Dict, Optional, List, Any
from dataclasses import dataclass
import datetime
import os
from pilot.awel import MapOperator
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.message import OnceConversation
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
# TODO move global config
CFG = Config()
@dataclass
class ChatContext:
    """Mutable context that flows through the AWEL chat operators.

    Each operator in the chat DAG reads and/or fills fields on this object;
    the final operator calls :meth:`build_model_payload`.
    """

    # Raw user question for the current round.
    current_user_input: str
    # LLM model name to route the request to.
    model_name: Optional[str]
    chat_session_id: Optional[str] = None
    # Scene-specific parameter (e.g. a knowledge space name).
    select_param: Optional[str] = None
    chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
    # Filled by PromptManagerOperator.
    prompt_template: Optional[PromptTemplate] = None
    chat_retention_rounds: Optional[int] = 0
    # Filled by ChatHistoryStorageOperator / ChatHistoryOperator.
    history_storage: Optional[BaseChatHistoryMemory] = None
    history_manager: Optional["ChatHistoryManager"] = None
    # The input values for prompt template
    input_values: Optional[Dict] = None
    echo: Optional[bool] = False

    def build_model_payload(self) -> Dict:
        """Build the request dict sent to the model worker.

        Requires ``input_values``, ``history_manager`` and
        ``prompt_template`` to have been populated by upstream operators.

        Raises:
            ValueError: When ``input_values`` is empty.
        """
        if not self.input_values:
            raise ValueError("The input value can't be empty")
        llm_messages = self.history_manager._new_chat(self.input_values)
        return {
            "model": self.model_name,
            "prompt": "",
            "messages": llm_messages,
            "temperature": float(self.prompt_template.temperature),
            "max_new_tokens": int(self.prompt_template.max_new_tokens),
            "echo": self.echo,
        }
class ChatHistoryManager:
    """Manage the message history of one chat round.

    Wraps the current ``OnceConversation`` plus the rounds persisted in
    ``history_storage`` and assembles the full LLM message list (system
    prompt, examples, retained history, user input).
    """

    def __init__(
        self,
        chat_ctx: ChatContext,
        prompt_template: PromptTemplate,
        history_storage: BaseChatHistoryMemory,
        chat_retention_rounds: Optional[int] = 0,
    ) -> None:
        self._chat_ctx = chat_ctx
        self.chat_retention_rounds = chat_retention_rounds
        # Conversation object for the round being processed right now.
        self.current_message: OnceConversation = OnceConversation(
            chat_ctx.chat_scene.value()
        )
        self.prompt_template = prompt_template
        self.history_storage: BaseChatHistoryMemory = history_storage
        # Previously persisted rounds loaded eagerly from storage.
        self.history_message: List[OnceConversation] = history_storage.messages()
        self.current_message.model_name = chat_ctx.model_name
        if chat_ctx.select_param:
            # Record the scene parameter (type taken from the scene's first
            # declared param type, when any are declared).
            if len(chat_ctx.chat_scene.param_types()) > 0:
                self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
            self.current_message.param_value = chat_ctx.select_param

    def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
        """Record the user turn and return the messages to send to the LLM.

        Side effects: mutates ``current_message`` (chat order, user
        message, start date, token counter, rendered system prompt).
        """
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self._chat_ctx.current_user_input)
        self.current_message.start_date = datetime.datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.current_message.tokens = 0
        if self.prompt_template.template:
            # Render the scene prompt with the operator-supplied values.
            current_prompt = self.prompt_template.format(**input_values)
            self.current_message.add_system_message(current_prompt)
        return self._generate_llm_messages()

    def _generate_llm_messages(self) -> List[ModelMessage]:
        """Assemble the ordered ModelMessage list for the model request."""
        # Imported lazily -- presumably to avoid a circular import with
        # pilot.scene.base_chat; confirm before moving to module level.
        from pilot.scene.base_chat import (
            _load_system_message,
            _load_example_messages,
            _load_history_messages,
            _load_user_message,
        )

        messages = []
        # Load scene setting or character definition as a system message.
        if self.prompt_template.template_define:
            messages.append(
                ModelMessage(
                    role=ModelMessageRoleType.SYSTEM,
                    content=self.prompt_template.template_define,
                )
            )
        # Load the rendered prompt (system conversations).
        messages += _load_system_message(
            self.current_message, self.prompt_template, str_message=False
        )
        # Load few-shot examples.
        messages += _load_example_messages(self.prompt_template, str_message=False)
        # Load retained history.
        messages += _load_history_messages(
            self.prompt_template,
            self.history_message,
            self.chat_retention_rounds,
            str_message=False,
        )
        # Load the current user input.
        messages += _load_user_message(
            self.current_message, self.prompt_template, str_message=False
        )
        return messages
class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
    """Resolve and attach the scene's ``PromptTemplate`` to the context.

    A template passed at construction time is used as-is; otherwise the
    template is looked up from the global registry on first use and
    cached on the operator afterwards.
    """

    def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
        super().__init__(**kwargs)
        self._prompt_template = prompt_template

    async def map(self, input_value: ChatContext) -> ChatContext:
        template = self._prompt_template
        if not template:
            # Lazy registry lookup; the result is cached on the operator,
            # so the first request's scene/model determines the template.
            template = CFG.prompt_template_registry.get_prompt_template(
                input_value.chat_scene.value(),
                language=CFG.LANGUAGE,
                model_name=input_value.model_name,
                proxyllm_backend=CFG.PROXYLLM_BACKEND,
            )
            self._prompt_template = template
        input_value.prompt_template = template
        return input_value
class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
    """Attach a chat-history storage backend to the chat context.

    Uses the storage injected at construction time when given; otherwise
    creates one from the ``ChatHistory`` factory keyed by the session id.
    """

    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        if self._history:
            # Bug fix: this previously returned the storage object itself,
            # breaking the MapOperator[ChatContext, ChatContext] contract
            # for every downstream operator. Attach it to the context and
            # return the context instead.
            input_value.history_storage = self._history
            return input_value
        chat_history_fac = ChatHistory()
        input_value.history_storage = chat_history_fac.get_store_instance(
            input_value.chat_session_id
        )
        return input_value
class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
    """Build the ``ChatHistoryManager`` for the current round.

    Falls back to an in-memory history store when neither the operator
    nor the incoming context supplies one.
    """

    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        storage = self._history if self._history else input_value.history_storage
        if not storage:
            # Last-resort default: keep history in process memory only.
            from pilot.memory.chat_history.store_type.mem_history import (
                MemHistoryMemory,
            )

            storage = MemHistoryMemory(input_value.chat_session_id)
            input_value.history_storage = storage
        manager = ChatHistoryManager(
            input_value,
            input_value.prompt_template,
            storage,
            input_value.chat_retention_rounds,
        )
        input_value.history_manager = manager
        return input_value
class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
    """KBQA retrieval step: fetch knowledge-space context for the question.

    Runs a similarity search against the knowledge space named by
    ``input_value.select_param`` and stores the retrieved context into
    ``input_value.input_values`` for later prompt rendering.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> ChatContext:
        from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        space_context = await self._get_space_context(knowledge_space)
        # Space-level settings override the global defaults when present.
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if space_context is None
            else int(space_context["embedding"]["topk"])
        )
        max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if space_context is None or space_context.get("prompt") is None
            else int(space_context["prompt"]["max_token"])
        )
        input_value.prompt_template.template_is_strict = False
        if space_context and space_context.get("prompt"):
            input_value.prompt_template.template_define = space_context["prompt"][
                "scene"
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]
        # similar_search is a blocking call; run it in the executor.
        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            input_value.current_user_input,
            top_k,
        )
        # Cleanup: removed a dead `_merge_by_key(...)` call whose result
        # was computed but never used.
        if not docs:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
            # NOTE(review): this truncates the *number of docs* to
            # max_token, not the token count -- looks suspicious; confirm.
            context = context[:max_token]
        relations = list(
            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
        )
        input_value.input_values = {
            "context": context,
            "question": input_value.current_user_input,
            "relations": relations,
        }
        return input_value

    async def _get_space_context(self, space_name):
        """Fetch the space's context config via KnowledgeService (blocking,
        so it is dispatched to the executor)."""
        from pilot.server.knowledge.service import KnowledgeService

        service = KnowledgeService()
        return await self.blocking_func_to_async(service.get_space_context, space_name)
class BaseChatOperator(MapOperator[ChatContext, Dict]):
    """Terminal chat operator: turn the prepared context into a model payload."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> Dict:
        # All heavy lifting happened upstream; just serialize the request.
        payload = input_value.build_model_payload()
        return payload

View File

@ -45,6 +45,7 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path param, system_app, embedding_model_name, embedding_model_path
) )
_initialize_model_cache(system_app) _initialize_model_cache(system_app)
_initialize_awel(system_app)
def _initialize_embedding_model( def _initialize_embedding_model(
@ -149,3 +150,10 @@ def _initialize_model_cache(system_app: SystemApp):
max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256 max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir) initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
def _initialize_awel(system_app: SystemApp):
    """Register the AWEL runtime and load DAGs from the definition directory."""
    # Imported lazily so AWEL is only pulled in when components initialize.
    from pilot.awel import initialize_awel
    from pilot.configs.model_config import _DAG_DEFINITION_DIR

    initialize_awel(system_app, _DAG_DEFINITION_DIR)

View File

@ -1,9 +1,14 @@
import argparse import argparse
import os import os
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
from typing import Any, List, Optional, Type, Union, Callable, Dict from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
from collections import OrderedDict from collections import OrderedDict
if TYPE_CHECKING:
from pydantic import BaseModel
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
@dataclass @dataclass
class ParameterDescription: class ParameterDescription:
@ -613,6 +618,64 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
return default_value return default_value
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    """Extract a ``ParameterDescription`` per field of a pydantic model class.

    Supports both pydantic v1 and v2 by branching on the major version.

    Args:
        model_cls: A pydantic ``BaseModel`` subclass (the class itself).

    Returns:
        One ``ParameterDescription`` per declared field, in schema order.
    """
    import pydantic

    version = int(pydantic.VERSION.split(".")[0])
    schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
    # Bug fix: `model_fields` exists only on pydantic v2; v1 exposes the
    # same mapping as `__fields__`.
    model_fields = model_cls.model_fields if version >= 2 else model_cls.__fields__
    required_fields = set(schema.get("required", []))
    param_descs = []
    for field_name, field_schema in schema.get("properties", {}).items():
        field = model_fields[field_name]
        param_type = field_schema.get("type")
        if not param_type and "anyOf" in field_schema:
            # Optional[X] renders as anyOf [X, null]; pick the non-null type.
            # Use .get(): anyOf members may be `$ref` entries without "type".
            for any_of in field_schema["anyOf"]:
                if any_of.get("type") and any_of["type"] != "null":
                    param_type = any_of["type"]
                    break
        if version >= 2:
            # v2 marks "no default" with the PydanticUndefined sentinel.
            default_value = (
                field.default
                if hasattr(field, "default")
                and str(field.default) != "PydanticUndefined"
                else None
            )
        else:
            default_value = (
                field.default
                if not field.allow_none
                else (
                    field.default_factory() if callable(field.default_factory) else None
                )
            )
        description = field_schema.get("description", "")
        is_required = field_name in required_fields
        valid_values = None
        ext_metadata = None
        if hasattr(field, "field_info"):
            # pydantic v1 keeps choices/extra under FieldInfo.
            valid_values = (
                list(field.field_info.choices)
                if hasattr(field.field_info, "choices")
                else None
            )
            ext_metadata = (
                field.field_info.extra if hasattr(field.field_info, "extra") else None
            )
        # Bug fix: a stray trailing comma previously turned param_class into
        # a one-element tuple instead of the dotted-path string.
        param_class = f"{model_cls.__module__}.{model_cls.__name__}"
        param_desc = ParameterDescription(
            param_class=param_class,
            param_name=field_name,
            param_type=param_type,
            default_value=default_value,
            description=description,
            required=is_required,
            valid_values=valid_values,
            ext_metadata=ext_metadata,
        )
        param_descs.append(param_desc)
    return param_descs
class _SimpleArgParser: class _SimpleArgParser:
def __init__(self, *args): def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args} self.params = {arg.replace("_", "-"): None for arg in args}