feat(awel): AWEL supports http trigger and add some AWEL examples (#815)

- AWEL supports http trigger.
- Disassemble the KBQA pipeline into atomic operators.
- Add some AWEL examples.
This commit is contained in:
Aries-ckt 2023-11-21 15:31:28 +08:00 committed by GitHub
commit 2ce77519de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1199 additions and 187 deletions

View File

@ -0,0 +1,54 @@
"""AWEL: Simple chat dag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"user_input": "hello"
}'
"""
from typing import Dict
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
from pilot.scene.base_message import ModelMessage
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator
class TriggerReqBody(BaseModel):
    """Request body schema for the ``/examples/simple_chat`` trigger."""

    model: str = Field(..., description="Model name")
    user_input: str = Field(..., description="User input")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Map the parsed HTTP request body into model inference parameters."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Build the parameter dict consumed by the downstream model task."""
        messages = [ModelMessage.build_human_message(input_value.user_input)]
        params = {
            "prompt": input_value.user_input,
            "messages": [message.dict() for message in messages],
            "model": input_value.model,
            "echo": False,
        }
        print(f"Receive input value: {input_value}")
        return params
with DAG("dbgpt_awel_simple_dag_example") as dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_chat", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    model_task = ModelOperator()
    # type(out) == ModelOutput
    model_parse_task = MapOperator(lambda out: out.to_dict())
    # Pipeline: HTTP request -> param dict -> model inference -> serialized output.
    trigger >> request_handle_task >> model_task >> model_parse_task

View File

@ -0,0 +1,32 @@
"""AWEL: Simple dag example
Example:
.. code-block:: shell
curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan
"""
from pydantic import BaseModel, Field
from pilot.awel import DAG, HttpTrigger, MapOperator
class TriggerReqBody(BaseModel):
    """Query-parameter schema for the ``/examples/hello`` trigger."""

    name: str = Field(..., description="User name")
    age: int = Field(18, description="User age")
class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
    """Turn the parsed request body into a greeting message."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> str:
        """Format a greeting from the request body fields."""
        print(f"Receive input value: {input_value}")
        greeting = "Hello, {}, your age is {}".format(
            input_value.name, input_value.age
        )
        return greeting
with DAG("simple_dag_example") as dag:
    # GET query parameters are mapped onto TriggerReqBody fields.
    trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
    map_node = RequestHandleOperator()
    trigger >> map_node

View File

@ -0,0 +1,70 @@
"""AWEL: Simple rag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
-H "Content-Type: application/json" -d '{
"conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
"model_name": "proxyllm",
"chat_mode": "chat_knowledge",
"user_input": "What is DB-GPT?",
"select_param": "default"
}'
"""
from pilot.awel import HttpTrigger, DAG, MapOperator
from pilot.scene.operator._experimental import (
ChatContext,
PromptManagerOperator,
ChatHistoryStorageOperator,
ChatHistoryOperator,
EmbeddingEngingOperator,
BaseChatOperator,
)
from pilot.scene.base import ChatScene
from pilot.openapi.api_view_model import ConversationVo
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator
class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
    """Convert the incoming ConversationVo request into a ChatContext."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ConversationVo) -> ChatContext:
        # The chat scene is fixed to knowledge-base QA in this example.
        return ChatContext(
            current_user_input=input_value.user_input,
            model_name=input_value.model_name,
            chat_session_id=input_value.conv_uid,
            select_param=input_value.select_param,
            chat_scene=ChatScene.ChatKnowledge,
        )
with DAG("simple_rag_example") as dag:
    trigger_task = HttpTrigger(
        "/examples/simple_rag", methods="POST", request_body=ConversationVo
    )
    req_parse_task = RequestParseOperator()
    prompt_task = PromptManagerOperator()
    history_storage_task = ChatHistoryStorageOperator()
    history_task = ChatHistoryOperator()
    embedding_task = EmbeddingEngingOperator()
    chat_task = BaseChatOperator()
    model_task = ModelOperator()
    # Extract only the generated text from the model output dict.
    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])
    # KBQA pipeline, decomposed into atomic operators:
    # request -> prompt -> history -> retrieval -> chat -> model -> text.
    (
        trigger_task
        >> req_parse_task
        >> prompt_task
        >> history_storage_task
        >> history_task
        >> embedding_task
        >> chat_task
        >> model_task
        >> output_parser_task
    )

View File

@ -1,8 +1,17 @@
"""Agentic Workflow Expression Language (AWEL)"""
"""Agentic Workflow Expression Language (AWEL)
Note:
AWEL is still an experimental feature and only opens the lowest level API.
The stability of this API cannot be guaranteed at present.
"""
from pilot.component import SystemApp
from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
@ -28,6 +37,7 @@ from .task.task_impl import (
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
@ -57,4 +67,21 @@ __all__ = [
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
    """Initialize AWEL: register managers, install the runner and load DAGs.

    Args:
        system_app: Application container components are registered on.
        dag_filepath: File or directory DAG definitions are loaded from.
    """
    from .dag.dag_manager import DAGManager
    from .dag.base import DAGVar
    from .trigger.trigger_manager import DefaultTriggerManager
    from .operator.base import initialize_runner

    # Make the system app visible to DAG nodes created afterwards.
    DAGVar.set_current_system_app(system_app)
    system_app.register(DefaultTriggerManager)
    dag_manager = DAGManager(system_app, dag_filepath)
    system_app.register_instance(dag_manager)
    initialize_runner(DefaultWorkflowRunner())
    # Load all dags
    dag_manager.load_dags()

7
pilot/awel/base.py Normal file
View File

@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class Trigger(ABC):
    """Abstract base for objects that can start a workflow run."""

    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""

View File

@ -1,14 +1,20 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor
from pilot.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
@ -96,6 +102,8 @@ class DependencyMixin(ABC):
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
@classmethod
def enter_dag(cls, dag) -> None:
@ -138,18 +146,48 @@ class DAGVar:
return cls._thread_local.current_dag_stack[-1]
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
    """Return the system app set earlier; raise if none was set."""
    if not cls._system_app:
        raise RuntimeError("System APP not set for DAGVar")
    return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
    """Set the process-wide system app; only the first call takes effect."""
    if cls._system_app:
        # `Logger.warn` is a deprecated alias of `Logger.warning`.
        logger.warning("System APP has already set, nothing to do")
    else:
        cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
    """Return the process-wide default executor (None until set)."""
    return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
    """Set the process-wide default executor for blocking work."""
    cls._executor = executor
class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
self,
dag: Optional["DAG"] = None,
node_id: Optional[str] = None,
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
@ -159,6 +197,10 @@ class DAGNode(DependencyMixin, ABC):
def node_id(self) -> str:
return self._node_id
@property
def system_app(self) -> SystemApp:
    """The system app bound to this node at construction time."""
    return self._system_app
def set_node_id(self, node_id: str) -> None:
    """Override this node's unique id."""
    self._node_id = node_id
@ -178,7 +220,7 @@ class DAGNode(DependencyMixin, ABC):
return self._node_name
@property
def dag(self) -> "DAGNode":
def dag(self) -> "DAG":
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
@ -254,17 +296,69 @@ class DAG:
def __init__(
    self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
    """Create a DAG.

    Args:
        dag_id: Unique identifier of this DAG.
        resource_group: Optional resource group.
            NOTE(review): currently not stored on the instance — confirm intended.
    """
    self._dag_id = dag_id
    # All nodes appended to this DAG, keyed by node id.
    self.node_map: Dict[str, DAGNode] = {}
    # Lazily built caches; None means "needs rebuild" (see _build).
    self._root_nodes: Set[DAGNode] = None
    self._leaf_nodes: Set[DAGNode] = None
    self._trigger_nodes: Set[DAGNode] = None
def _append_node(self, node: DAGNode) -> None:
    """Register *node* on this DAG and invalidate the cached node views."""
    self.node_map[node.node_id] = node
    # Clear all cached views so they are rebuilt lazily on next access.
    # The original cleared only root/leaf caches, leaving trigger_nodes
    # stale if a trigger was appended after the caches were first built.
    self._root_nodes = None
    self._leaf_nodes = None
    self._trigger_nodes = None
def _new_node_id(self) -> str:
    """Generate a unique id for a node joining this DAG."""
    return str(uuid.uuid4())
@property
def dag_id(self) -> str:
    """Unique identifier of this DAG."""
    return self._dag_id
def _build(self) -> None:
    """Collect every node reachable from this DAG and cache the root,
    leaf and trigger node views."""
    from ..operator.common_operator import TriggerOperator

    nodes = set()
    for node in self.node_map.values():
        nodes |= _get_nodes(node)
    self._root_nodes = [n for n in nodes if not n.upstream]
    self._leaf_nodes = [n for n in nodes if not n.downstream]
    self._trigger_nodes = [n for n in nodes if isinstance(n, TriggerOperator)]
@property
def root_nodes(self) -> List[DAGNode]:
    """Nodes with no upstream dependencies (rebuilt lazily)."""
    if not self._root_nodes:
        self._build()
    return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
    """Nodes with no downstream dependents (rebuilt lazily)."""
    if not self._leaf_nodes:
        self._build()
    return self._leaf_nodes
@property
def trigger_nodes(self) -> List[DAGNode]:
    """Nodes that are TriggerOperator instances (rebuilt lazily)."""
    if not self._trigger_nodes:
        self._build()
    return self._trigger_nodes
def __enter__(self):
    """Push this DAG as the current DAG for ``with DAG(...)`` blocks."""
    DAGVar.enter_dag(self)
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """Pop this DAG from the current-DAG stack."""
    DAGVar.exit_dag()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes

View File

@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
    """Component that loads DAG definitions and registers their triggers."""

    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        """Create a DAG manager.

        Args:
            system_app: Container used to look up the trigger manager.
            dag_filepath: File or directory the DAG loader reads from.
        """
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        """Load all DAGs and register their trigger nodes.

        Raises:
            ValueError: If two loaded DAGs share the same DAG ID.
        """
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
            # Record the DAG so the duplicate-ID check can actually fire;
            # the original never populated dag_map.
            self.dag_map[dag_id] = dag
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            # `Logger.warn` is deprecated; use `warning`.
            logger.warning("No trigger manager, not register dag trigger")

93
pilot/awel/dag/loader.py Normal file
View File

@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
    """Strategy interface for discovering DAG definitions."""

    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags"""
class LocalFileDAGLoader(DAGLoader):
    """Load DAGs from a local Python file, or from a directory of them."""

    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        """Return all DAGs found at the configured path (empty if missing)."""
        path = self._filepath
        if not os.path.exists(path):
            return []
        if os.path.isdir(path):
            return _process_directory(path)
        return _process_file(path)
def _process_directory(directory: str) -> List[DAG]:
    """Load DAGs from every ``.py`` file directly inside *directory*."""
    dags = []
    for entry in os.listdir(directory):
        if entry.endswith(".py"):
            dags.extend(_process_file(os.path.join(directory, entry)))
    return dags
def _process_file(filepath) -> List[DAG]:
    """Import *filepath* as a module and collect the DAGs it defines."""
    return _process_modules(_load_modules_from_file(filepath))
def _load_modules_from_file(filepath: str):
    """Import the Python file at *filepath* under a unique module name.

    The module name is prefixed with a hash of the path so files with the
    same basename in different directories cannot collide in sys.modules.

    Returns:
        A one-element list with the imported module, or an empty list if
        the import raised (the error is logged, not propagated).
    """
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
    if mod_name in sys.modules:
        # Drop any stale module so the file is re-executed.
        del sys.modules[mod_name]
    # Flattened: the original wrapped this in a single-use inner function.
    try:
        loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
        spec = importlib.util.spec_from_loader(mod_name, loader)
        new_module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = new_module
        loader.exec_module(new_module)
        return [new_module]
    except Exception:
        msg = traceback.format_exc()
        logger.error(f"Failed to import: {filepath}, error message: {msg}")
        # TODO save error message
        return []
def _process_modules(mods) -> List[DAG]:
    """Collect all module-level DAG objects defined in *mods*.

    Errors on a single DAG are logged and skipped so one bad definition
    does not abort loading of the others.
    """
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            # Fixed garbled log text ("Failed to dag file").
            logger.error(f"Failed to process dag file, error message: {msg}")
    return found_dags

View File

@ -14,6 +14,13 @@ from typing import (
)
import functools
from inspect import signature
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
@ -67,6 +74,19 @@ class BaseOperatorMeta(ABCMeta):
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id")
system_app: Optional[SystemApp] = (
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
executor = kwargs.get("executor") or DAGVar.get_executor()
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
@ -80,6 +100,10 @@ class BaseOperatorMeta(ABCMeta):
kwargs["task_id"] = task_id
if not kwargs.get("runner"):
kwargs["runner"] = runner
if not kwargs.get("system_app"):
kwargs["system_app"] = system_app
if not kwargs.get("executor"):
kwargs["executor"] = executor
real_obj = func(self, *args, **kwargs)
return real_obj
@ -171,7 +195,12 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output_stream
async def blocking_func_to_async(
    self, func: BlockingFunction, *args, **kwargs
) -> Any:
    """Run a blocking function in this operator's executor so it does not
    block the event loop.

    Delegates to the module-level helper of the same name (the method name
    shadows it only inside this class namespace).
    """
    return await blocking_func_to_async(self._executor, func, *args, **kwargs)
def initialize_awel(runner: WorkflowRunner):
def initialize_runner(runner: WorkflowRunner):
global default_runner
default_runner = runner

View File

@ -237,3 +237,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
    """Input operator fed by an external trigger through call data,
    using a SimpleCallDataInputSource."""

    def __init__(self, **kwargs) -> None:
        # Imported lazily; presumably avoids an import cycle — confirm.
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)

View File

@ -3,7 +3,7 @@ import logging
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs[node.node_id] = task_ctx
return
try:
logger.info(
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
await node._run(dag_ctx)
@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
if isinstance(node, BranchOperator):
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
logger.info(
logger.debug(
f"Current is branch operator, skip node names: {skip_nodes}"
)
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput
class Trigger(TriggerOperator, ABC):
    """Operator that starts a workflow from an external event."""

    @abstractmethod
    async def trigger(self, end_operator: "BaseOperator") -> None:
        """Trigger the workflow or a specific operation in the workflow."""

View File

@ -0,0 +1,117 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from pydantic import BaseModel
import logging
from .base import Trigger
from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
    """Trigger that exposes a DAG as an HTTP endpoint on a FastAPI router.

    When the endpoint is called, the request is parsed (see
    ``_parse_request_body``) and passed as call data to the DAG's single
    leaf operator; the leaf's result (or stream) becomes the response.
    """

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        """Create an HTTP trigger.

        Args:
            endpoint: Route path; a leading "/" is added if missing.
            methods: HTTP method, or list of methods, to serve.
            request_body: Pydantic model class (or Request/str) used to
                parse the incoming request.
            streaming_response: If True, respond with a StreamingResponse
                fed by the leaf operator's output stream.
            response_model: FastAPI response model for the route.
            response_headers: Headers used for the streaming response.
            response_media_type: Media type for the streaming response
                (defaults to "text/event-stream").
            status_code: HTTP status code of the route.
            router_tags: Tags applied to the mounted route.
        """
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        self._endpoint = endpoint
        self._methods = methods
        self._req_body = request_body
        self._streaming_response = streaming_response
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        self._end_node: BaseOperator = None

    async def trigger(self) -> None:
        # HTTP triggers are driven by the mounted route, not called directly.
        pass

    def mount_to_router(self, router: "APIRouter") -> None:
        """Register this trigger's endpoint as a route on *router*."""
        from fastapi import Depends
        from fastapi.responses import StreamingResponse

        # Normalize to a list so a single method string is accepted too.
        methods = self._methods if isinstance(self._methods, list) else [self._methods]

        def create_route_function(name):
            async def _request_body_dependency(request: Request):
                return await _parse_request_body(request, self._req_body)

            async def route_function(body: Any = Depends(_request_body_dependency)):
                # The DAG must converge to exactly one leaf operator, which
                # receives the parsed request body as call data.
                end_node = self.dag.leaf_nodes
                if len(end_node) != 1:
                    raise ValueError("HttpTrigger just support one leaf node in dag")
                end_node = end_node[0]
                if not self._streaming_response:
                    return await end_node.call(call_data={"data": body})
                else:
                    headers = self._response_headers
                    media_type = (
                        self._response_media_type
                        if self._response_media_type
                        else "text/event-stream"
                    )
                    if not headers:
                        headers = {
                            "Content-Type": "text/event-stream",
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Transfer-Encoding": "chunked",
                        }
                    return StreamingResponse(
                        end_node.call_stream(call_data={"data": body}),
                        headers=headers,
                        media_type=media_type,
                    )

            # FastAPI uses the function name in its route metadata.
            route_function.__name__ = name
            return route_function

        function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
        dynamic_route_function = create_route_function(function_name)
        logger.info(
            f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
        )
        router.api_route(
            self._endpoint,
            methods=methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
        )(dynamic_route_function)
async def _parse_request_body(
    request: Request, request_body_cls: Optional[Type[BaseModel]]
):
    """Parse *request* into *request_body_cls*: JSON body for POST, query
    params for GET; other methods return the raw request unchanged.
    Returns None when no body class is configured."""
    if not request_body_cls:
        return None
    method = request.method
    if method == "POST":
        payload = await request.json()
        return request_body_cls(**payload)
    if method == "GET":
        return request_body_cls(**request.query_params)
    return request

View File

@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
if TYPE_CHECKING:
from fastapi import APIRouter
from pilot.component import SystemApp, BaseComponent, ComponentType
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
    """Interface for components that accept trigger registrations."""

    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager"""
class HttpTriggerManager(TriggerManager):
    """Collects HttpTrigger instances and mounts them on one APIRouter."""

    def __init__(
        self,
        router: Optional["APIRouter"] = None,
        router_prefix: Optional[str] = "/api/v1/awel/trigger",
    ) -> None:
        """Create the manager.

        Args:
            router: Router to mount routes on; a new APIRouter is created
                when omitted.
            router_prefix: Prefix path the router is included under.
        """
        if not router:
            from fastapi import APIRouter

            router = APIRouter()
        self._router_prefix = router_prefix
        self._router = router
        # trigger node id -> trigger; prevents mounting the same route twice.
        self._trigger_map = {}

    def register_trigger(self, trigger: Any) -> None:
        """Mount *trigger* on the router (idempotent per trigger node id).

        Raises:
            ValueError: If *trigger* is not an HttpTrigger.
        """
        from .http_trigger import HttpTrigger

        if not isinstance(trigger, HttpTrigger):
            raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
        trigger: HttpTrigger = trigger
        trigger_id = trigger.node_id
        if trigger_id not in self._trigger_map:
            trigger.mount_to_router(self._router)
            self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
        # Attach all accumulated routes to the FastAPI app under the prefix.
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
        system_app.app.include_router(
            self._router, prefix=self._router_prefix, tags=["AWEL"]
        )
class DefaultTriggerManager(TriggerManager, BaseComponent):
    """System component that dispatches triggers to protocol-specific
    managers. Currently only HTTP triggers are supported."""

    name = ComponentType.AWEL_TRIGGER_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        self.system_app = system_app
        self.http_trigger = HttpTriggerManager()
        super().__init__(None)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def register_trigger(self, trigger: Any) -> None:
        """Route *trigger* to the matching protocol manager.

        Raises:
            ValueError: If the trigger type is not supported.
        """
        from .http_trigger import HttpTrigger

        if isinstance(trigger, HttpTrigger):
            logger.info(f"Register trigger {trigger}")
            self.http_trigger.register_trigger(trigger)
        else:
            # Fixed typo in error message ("Unsupport" -> "Unsupported").
            raise ValueError(f"Unsupported trigger: {trigger}")

    def after_register(self) -> None:
        """Mount all registered HTTP routes onto the app's router."""
        self.http_trigger._init_app(self.system_app)

View File

@ -54,6 +54,8 @@ class ComponentType(str, Enum):
TRACER = "dbgpt_tracer"
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
class BaseComponent(LifeCycle, ABC):

View File

@ -16,6 +16,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
current_directory = os.getcwd()

View File

@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
logger.error("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None:
logger.info(f"db history append: {once_message}")
logger.debug(f"db history append: {once_message}")
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
self.chat_seesion_id
)

View File

@ -7,9 +7,10 @@ from pilot.awel import (
MapOperator,
TransformStreamAbsOperator,
)
from pilot.component import ComponentType
from pilot.awel.operator.base import BaseOperator
from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager
from pilot.model.cluster import WorkerManager, WorkerManagerFactory
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
logger = logging.getLogger(__name__)
@ -29,7 +30,7 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
streamify: Asynchronously processes a stream of inputs, yielding model outputs.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
@ -42,6 +43,10 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
Returns:
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
"""
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for out in self.worker_manager.generate_stream(input_value):
yield out
@ -57,9 +62,9 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
map: Asynchronously processes a single input and returns the model output.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
self.worker_manager = worker_manager
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
async def map(self, input_value: Dict) -> ModelOutput:
"""Process a single input and return the model output.
@ -70,6 +75,10 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
Returns:
ModelOutput: The output from the model.
"""
if not self.worker_manager:
self.worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return await self.worker_manager.generate(input_value)

View File

@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
logger.info(
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
)
logger.info(f"Send request to real model {proxyllm_backend}")
return history, payloads

View File

@ -68,7 +68,7 @@ class BaseChat(ABC):
CFG.prompt_template_registry.get_prompt_template(
self.chat_mode.value(),
language=CFG.LANGUAGE,
model_name=CFG.LLM_MODEL,
model_name=self.llm_model,
proxyllm_backend=CFG.PROXYLLM_BACKEND,
)
)
@ -141,13 +141,7 @@ class BaseChat(ABC):
return speak_to_user
async def __call_base(self):
import inspect
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())
else self.generate_input_values()
)
input_values = await self.generate_input_values()
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
@ -379,16 +373,18 @@ class BaseChat(ABC):
if self.prompt_template.template_define:
text += self.prompt_template.template_define + self.prompt_template.sep
### Load prompt
text += self.__load_system_message()
text += _load_system_message(self.current_message, self.prompt_template)
### Load examples
text += self.__load_example_messages()
text += _load_example_messages(self.prompt_template)
### Load History
text += self.__load_history_messages()
text += _load_history_messages(
self.prompt_template, self.history_message, self.chat_retention_rounds
)
### Load User Input
text += self.__load_user_message()
text += _load_user_message(self.current_message, self.prompt_template)
return text
def generate_llm_messages(self) -> List[ModelMessage]:
@ -406,137 +402,26 @@ class BaseChat(ABC):
)
)
### Load prompt
messages += self.__load_system_message(str_message=False)
messages += _load_system_message(
self.current_message, self.prompt_template, str_message=False
)
### Load examples
messages += self.__load_example_messages(str_message=False)
messages += _load_example_messages(self.prompt_template, str_message=False)
### Load History
messages += self.__load_history_messages(str_message=False)
messages += _load_history_messages(
self.prompt_template,
self.history_message,
self.chat_retention_rounds,
str_message=False,
)
### Load User Input
messages += self.__load_user_message(str_message=False)
messages += _load_user_message(
self.current_message, self.prompt_template, str_message=False
)
return messages
def __load_system_message(self, str_message: bool = True):
system_convs = self.current_message.get_system_conv()
system_text = ""
system_messages = []
for system_conv in system_convs:
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content)
)
return system_text if str_message else system_messages
def __load_user_message(self, str_message: bool = True):
user_conv = self.current_message.get_user_conv()
user_messages = []
if user_conv:
user_text = (
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
)
user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content)
)
return user_text if str_message else user_messages
else:
raise ValueError("Hi! What do you want to talk about")
def __load_example_messages(self, str_message: bool = True):
example_text = ""
example_messages = []
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
example_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
example_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return example_text if str_message else example_messages
def __load_history_messages(self, str_message: bool = True):
history_text = ""
history_messages = []
if self.prompt_template.need_historical_messages:
if self.history_message:
logger.info(
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
)
if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
if self.chat_retention_rounds > 1:
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(
role=message_type, content=message_content
)
)
else:
### user all history
for conversation in self.history_message:
for message in conversation["messages"]:
### histroy message not have promot and view info
if not message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = message["type"]
message_content = message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return history_text if str_message else history_messages
def current_ai_response(self) -> str:
for message in self.current_message.messages:
if message.type == "view":
@ -656,3 +541,127 @@ def _build_model_operator(
cache_check_branch_node >> cached_node >> join_node
return join_node
def _load_system_message(
current_message: OnceConversation,
prompt_template: PromptTemplate,
str_message: bool = True,
):
system_convs = current_message.get_system_conv()
system_text = ""
system_messages = []
for system_conv in system_convs:
system_text += (
system_conv.type + ":" + system_conv.content + prompt_template.sep
)
system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content)
)
return system_text if str_message else system_messages
def _load_user_message(
current_message: OnceConversation,
prompt_template: PromptTemplate,
str_message: bool = True,
):
user_conv = current_message.get_user_conv()
user_messages = []
if user_conv:
user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content)
)
return user_text if str_message else user_messages
else:
raise ValueError("Hi! What do you want to talk about")
def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
example_text = ""
example_messages = []
if prompt_template.example_selector:
for round_conv in prompt_template.example_selector.examples():
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
example_text += (
message_type + ":" + message_content + prompt_template.sep
)
example_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return example_text if str_message else example_messages
def _load_history_messages(
prompt_template: PromptTemplate,
history_message: List[OnceConversation],
chat_retention_rounds: int,
str_message: bool = True,
):
history_text = ""
history_messages = []
if prompt_template.need_historical_messages:
if history_message:
logger.info(
f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
)
if len(history_message) > chat_retention_rounds:
for first_message in history_message[0]["messages"]:
if not first_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
message_type + ":" + message_content + prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
if chat_retention_rounds > 1:
index = chat_retention_rounds - 1
for round_conv in history_message[-index:]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
else:
### user all history
for conversation in history_message:
for message in conversation["messages"]:
### histroy message not have promot and view info
if not message["type"] in [
ModelMessageRoleType.VIEW,
ModelMessageRoleType.SYSTEM,
]:
message_type = message["type"]
message_content = message["data"]["content"]
history_text += (
message_type + ":" + message_content + prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return history_text if str_message else history_messages

View File

@ -117,6 +117,10 @@ class ModelMessage(BaseModel):
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":
    """Create a ModelMessage with the HUMAN role from raw user input text."""
    return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
class Generation(BaseModel):
"""Output of a single generation."""

View File

@ -6,7 +6,6 @@ import re
import sqlparse
import pandas as pd
import chardet
import pandas as pd
import numpy as np
from pyparsing import (
CaselessKeyword,
@ -27,6 +26,8 @@ from pyparsing import (
from pilot.common.pd_utils import csv_colunm_foramt
from pilot.common.string_utils import is_chinese_include_number
logger = logging.getLogger(__name__)
def excel_colunm_format(old_name: str) -> str:
new_column = old_name.strip()
@ -263,7 +264,7 @@ class ExcelReader:
file_name = os.path.basename(file_path)
self.file_name_without_extension = os.path.splitext(file_name)[0]
encoding, confidence = detect_encoding(file_path)
logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1]
# read excel file
@ -323,7 +324,7 @@ class ExcelReader:
colunms.append(descrip[0])
return colunms, results.fetchall()
except Exception as e:
logging.error("excel sql run error!", e)
logger.error(f"excel sql run error!, {str(e)}")
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
def get_df_by_sql_ex(self, sql):

View File

@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
logging.info("clean prompt response:", clean_str)
logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model
if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "")
@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser):
thoughts = response[key]
return SqlAction(sql, thoughts)
except Exception as e:
logging.error("json load faild")
logger.error("json load faild")
return SqlAction("", clean_str)
def parse_view_response(self, speak, data, prompt_response) -> str:

View File

@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
"text": self.user_input,
}

View File

@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
"text": self.user_input,
}

View File

@ -23,7 +23,7 @@ class ExtractRefineSummary(BaseChat):
self.existing_answer = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
# "context": self.user_input,
"existing_answer": self.existing_answer,

View File

@ -23,7 +23,7 @@ class ExtractSummary(BaseChat):
self.user_input = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
"context": self.user_input,
}

View File

@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat):
self.current_user_input,
self.top_k,
)
self.sources = self.merge_by_key(
self.sources = _merge_by_key(
list(map(lambda doc: doc.metadata, docs)), "source"
)
@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat):
)
return html
def merge_by_key(self, data, key):
    """Group metadata dicts by the basename of ``item[key]``.

    Each distinct basename yields one ``{"source": ..., "pages": [...]}``
    entry; the ``"pages"`` list collects the string form of every ``"page"``
    value seen and is omitted when no item carried a page. Items whose
    ``key`` value is falsy/missing are ignored.

    Fix: previously, when a source first appeared without a ``"page"`` and a
    later item for the same source had one, the merge read the non-existent
    ``result[item_key]["pages"]`` and raised KeyError. Now the page list is
    created on demand.
    """
    result = {}
    for item in data:
        if item.get(key):
            item_key = os.path.basename(item.get(key))
            if item_key in result:
                if "pages" in result[item_key] and "page" in item:
                    result[item_key]["pages"].append(str(item["page"]))
                elif "page" in item:
                    # Entry was first seen without a page; start its list now.
                    result[item_key]["pages"] = [str(item["page"])]
            else:
                if "page" in item:
                    result[item_key] = {
                        "source": item_key,
                        "pages": [str(item["page"])],
                    }
                else:
                    result[item_key] = {"source": item_key}
    return list(result.values())
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value()
@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat):
def get_space_context(self, space_name):
service = KnowledgeService()
return service.get_space_context(space_name)
def _merge_by_key(data, key):
result = {}
for item in data:
if item.get(key):
item_key = os.path.basename(item.get(key))
if item_key in result:
if "pages" in result[item_key] and "page" in item:
result[item_key]["pages"].append(str(item["page"]))
elif "page" in item:
result[item_key]["pages"] = [
result[item_key]["pages"],
str(item["page"]),
]
else:
if "page" in item:
result[item_key] = {
"source": item_key,
"pages": [str(item["page"])],
}
else:
result[item_key] = {"source": item_key}
return list(result.values())

View File

View File

@ -0,0 +1,255 @@
from typing import Dict, Optional, List, Any
from dataclasses import dataclass
import datetime
import os
from pilot.awel import MapOperator
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.message import OnceConversation
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
# TODO move global config
CFG = Config()
@dataclass
class ChatContext:
    """Mutable context threaded through the AWEL chat operators.

    Each operator in the chat DAG reads and/or fills in fields of this
    object and passes it downstream; ``build_model_payload`` is the final
    step that turns it into the dict sent to the model service.
    """

    # Raw text the user typed for this round.
    current_user_input: str
    # Target model name; forwarded as-is in the payload.
    model_name: Optional[str]
    # Conversation/session identifier used to locate history storage.
    chat_session_id: Optional[str] = None
    # Scene-specific parameter (e.g. knowledge space name).
    select_param: Optional[str] = None
    chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
    # Filled in by PromptManagerOperator.
    prompt_template: Optional[PromptTemplate] = None
    # Max conversation rounds replayed as history (0 = manager default handling).
    chat_retention_rounds: Optional[int] = 0
    # Filled in by ChatHistoryStorageOperator / ChatHistoryOperator.
    history_storage: Optional[BaseChatHistoryMemory] = None
    history_manager: Optional["ChatHistoryManager"] = None
    # The input values for prompt template
    input_values: Optional[Dict] = None
    # Whether the model should echo the prompt back in its output.
    echo: Optional[bool] = False

    def build_model_payload(self) -> Dict:
        """Render this context into the request dict for the model service.

        :raises ValueError: when ``input_values`` has not been populated yet.
        """
        if not self.input_values:
            raise ValueError("The input value can't be empty")
        # Delegates prompt/history assembly to the history manager.
        llm_messages = self.history_manager._new_chat(self.input_values)
        return {
            "model": self.model_name,
            "prompt": "",
            "messages": llm_messages,
            "temperature": float(self.prompt_template.temperature),
            "max_new_tokens": int(self.prompt_template.max_new_tokens),
            "echo": self.echo,
        }
class ChatHistoryManager:
    """Assemble the LLM message list from history, prompt and user input.

    Loads prior rounds from ``history_storage`` at construction time and
    builds the ordered ModelMessage list for a new round via ``_new_chat``.
    """

    def __init__(
        self,
        chat_ctx: ChatContext,
        prompt_template: PromptTemplate,
        history_storage: BaseChatHistoryMemory,
        chat_retention_rounds: Optional[int] = 0,
    ) -> None:
        self._chat_ctx = chat_ctx
        self.chat_retention_rounds = chat_retention_rounds
        # Conversation record for the round currently being built.
        self.current_message: OnceConversation = OnceConversation(
            chat_ctx.chat_scene.value()
        )
        self.prompt_template = prompt_template
        self.history_storage: BaseChatHistoryMemory = history_storage
        # NOTE: history is read once here, not refreshed per call.
        self.history_message: List[OnceConversation] = history_storage.messages()
        self.current_message.model_name = chat_ctx.model_name
        if chat_ctx.select_param:
            # Record the scene parameter (first declared param type wins).
            if len(chat_ctx.chat_scene.param_types()) > 0:
                self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
            self.current_message.param_value = chat_ctx.select_param

    def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
        """Record the new user round and return the full LLM message list.

        :param input_values: values interpolated into the prompt template.
        """
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self._chat_ctx.current_user_input)
        self.current_message.start_date = datetime.datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.current_message.tokens = 0
        if self.prompt_template.template:
            # Rendered prompt becomes a system message of this round.
            current_prompt = self.prompt_template.format(**input_values)
            self.current_message.add_system_message(current_prompt)
        return self._generate_llm_messages()

    def _generate_llm_messages(self) -> List[ModelMessage]:
        """Compose messages in order: scene define, system, examples, history, user."""
        # Imported lazily to avoid a circular import with base_chat.
        from pilot.scene.base_chat import (
            _load_system_message,
            _load_example_messages,
            _load_history_messages,
            _load_user_message,
        )

        messages = []
        ### Load scene setting or character definition as system message
        if self.prompt_template.template_define:
            messages.append(
                ModelMessage(
                    role=ModelMessageRoleType.SYSTEM,
                    content=self.prompt_template.template_define,
                )
            )
        ### Load prompt
        messages += _load_system_message(
            self.current_message, self.prompt_template, str_message=False
        )
        ### Load examples
        messages += _load_example_messages(self.prompt_template, str_message=False)
        ### Load History
        messages += _load_history_messages(
            self.prompt_template,
            self.history_message,
            self.chat_retention_rounds,
            str_message=False,
        )
        ### Load User Input
        messages += _load_user_message(
            self.current_message, self.prompt_template, str_message=False
        )
        return messages
class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
    """Resolve and attach the PromptTemplate for the request's chat scene."""

    def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
        # An explicit template bypasses registry lookup entirely.
        super().__init__(**kwargs)
        self._prompt_template = prompt_template

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Fill ``input_value.prompt_template`` and pass the context on."""
        if not self._prompt_template:
            # NOTE(review): the looked-up template is cached on the operator, so
            # the first request's scene/model pins the template for all later
            # requests through this operator instance — confirm this is intended.
            self._prompt_template: PromptTemplate = (
                CFG.prompt_template_registry.get_prompt_template(
                    input_value.chat_scene.value(),
                    language=CFG.LANGUAGE,
                    model_name=input_value.model_name,
                    proxyllm_backend=CFG.PROXYLLM_BACKEND,
                )
            )
        input_value.prompt_template = self._prompt_template
        return input_value
class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
    """Attach a chat-history storage backend to the chat context.

    Uses the storage injected at construction time when available; otherwise
    resolves one from the ChatHistory factory by session id.
    """

    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Set ``input_value.history_storage`` and return the context.

        Bug fix: the injected-storage branch previously returned the storage
        object itself, violating the MapOperator[ChatContext, ChatContext]
        contract and breaking every downstream operator that expects the
        ChatContext.
        """
        if self._history:
            input_value.history_storage = self._history
            return input_value
        chat_history_fac = ChatHistory()
        input_value.history_storage = chat_history_fac.get_store_instance(
            input_value.chat_session_id
        )
        return input_value
class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
    """Build the ChatHistoryManager for the request and attach it to the context."""

    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Resolve a history storage and create the round's history manager."""
        storage = self._history if self._history else input_value.history_storage
        if not storage:
            # Nothing configured upstream: fall back to an in-memory store.
            from pilot.memory.chat_history.store_type.mem_history import (
                MemHistoryMemory,
            )

            storage = MemHistoryMemory(input_value.chat_session_id)
            input_value.history_storage = storage
        manager = ChatHistoryManager(
            input_value,
            input_value.prompt_template,
            storage,
            input_value.chat_retention_rounds,
        )
        input_value.history_manager = manager
        return input_value
class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
    """Retrieve knowledge-space documents and fill the prompt input values.

    (Class name typo "Enging" is kept for backward compatibility with
    existing imports.)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> ChatContext:
        """Run similarity search for the user input and populate
        ``input_value.input_values`` with context/question/relations.
        """
        from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        # Space-level settings override the global search defaults.
        space_context = await self._get_space_context(knowledge_space)
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if space_context is None
            else int(space_context["embedding"]["topk"])
        )
        max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if space_context is None or space_context.get("prompt") is None
            else int(space_context["prompt"]["max_token"])
        )
        input_value.prompt_template.template_is_strict = False
        if space_context and space_context.get("prompt"):
            input_value.prompt_template.template_define = space_context["prompt"][
                "scene"
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]
        # similar_search is blocking; run it off the event loop.
        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            input_value.current_user_input,
            top_k,
        )
        # Dead code removed: a `sources = _merge_by_key(...)` result was
        # computed here but never used.
        if not docs:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
            # NOTE(review): this slices the *list of documents* by max_token,
            # i.e. truncates by document count rather than token count —
            # confirm intent.
            context = context[:max_token]
        relations = list(
            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
        )
        input_value.input_values = {
            "context": context,
            "question": input_value.current_user_input,
            "relations": relations,
        }
        return input_value

    async def _get_space_context(self, space_name):
        """Fetch the knowledge-space configuration without blocking the loop."""
        from pilot.server.knowledge.service import KnowledgeService

        service = KnowledgeService()
        return await self.blocking_func_to_async(service.get_space_context, space_name)
class BaseChatOperator(MapOperator[ChatContext, Dict]):
    """Terminal chat operator: turn the assembled ChatContext into the
    request payload dict for the model service.

    (No custom state, so the default MapOperator constructor suffices.)
    """

    async def map(self, input_value: ChatContext) -> Dict:
        """Delegate payload construction to the context itself."""
        return input_value.build_model_payload()

View File

@ -45,6 +45,7 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
def _initialize_embedding_model(
@ -149,3 +150,10 @@ def _initialize_model_cache(system_app: SystemApp):
max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
def _initialize_awel(system_app: SystemApp):
    """Register the AWEL runtime with the app and load DAG definitions
    from the configured DAG definition directory."""
    # Imported lazily so AWEL is only pulled in when components initialize.
    from pilot.awel import initialize_awel
    from pilot.configs.model_config import _DAG_DEFINITION_DIR

    initialize_awel(system_app, _DAG_DEFINITION_DIR)

View File

@ -1,9 +1,14 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
from typing import Any, List, Optional, Type, Union, Callable, Dict
from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
from collections import OrderedDict
if TYPE_CHECKING:
from pydantic import BaseModel
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
@dataclass
class ParameterDescription:
@ -613,6 +618,64 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
return default_value
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    """Build a ParameterDescription for every field of a pydantic model.

    Supports both pydantic v1 and v2 by switching on the installed major
    version: the JSON schema, the field registry and the default-value
    extraction all differ between the two.

    Fixes:
      * pydantic v1 models expose ``__fields__``, not ``model_fields``
        (v2-only); the old code raised AttributeError on v1.
      * ``anyOf`` schema members may omit "type" (e.g. ``$ref`` entries);
        indexing raised KeyError — use ``.get`` instead.
      * ``param_class`` had a trailing comma that silently made it a
        1-tuple instead of the intended dotted-path string.
    """
    import pydantic

    version = int(pydantic.VERSION.split(".")[0])
    schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
    required_fields = set(schema.get("required", []))
    # v2: model_fields; v1: __fields__ — same mapping of name -> field object.
    fields_registry = model_cls.model_fields if version >= 2 else model_cls.__fields__
    param_descs = []
    for field_name, field_schema in schema.get("properties", {}).items():
        field = fields_registry[field_name]
        param_type = field_schema.get("type")
        if not param_type and "anyOf" in field_schema:
            # Optional[X] renders as anyOf [X, null]; pick the non-null type.
            for any_of in field_schema["anyOf"]:
                if any_of.get("type") and any_of["type"] != "null":
                    param_type = any_of["type"]
                    break
        if version >= 2:
            # v2 marks "no default" with the PydanticUndefined sentinel.
            default_value = (
                field.default
                if hasattr(field, "default")
                and str(field.default) != "PydanticUndefined"
                else None
            )
        else:
            default_value = (
                field.default
                if not field.allow_none
                else (
                    field.default_factory() if callable(field.default_factory) else None
                )
            )
        description = field_schema.get("description", "")
        is_required = field_name in required_fields
        valid_values = None
        ext_metadata = None
        if hasattr(field, "field_info"):
            # v1 keeps choices/extra metadata on field_info.
            valid_values = (
                list(field.field_info.choices)
                if hasattr(field.field_info, "choices")
                else None
            )
            ext_metadata = (
                field.field_info.extra if hasattr(field.field_info, "extra") else None
            )
        param_class = f"{model_cls.__module__}.{model_cls.__name__}"
        param_desc = ParameterDescription(
            param_class=param_class,
            param_name=field_name,
            param_type=param_type,
            default_value=default_value,
            description=description,
            required=is_required,
            valid_values=valid_values,
            ext_metadata=ext_metadata,
        )
        param_descs.append(param_desc)
    return param_descs
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}