chore: Merge latest code

author Fangyin Cheng
date 2024-08-30 15:02:53 +08:00
parent bf63a967b5
commit 0e71991f7e
8 changed files with 169 additions and 22 deletions

View File

@@ -619,6 +619,7 @@ class DAGContext:
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
self._event_loop_task_id = event_loop_task_id
self._dag_variables = dag_variables
self._share_data_lock = asyncio.Lock()
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -680,8 +681,9 @@ class DAGContext:
Returns:
Any: The share data; you can cast it to the real type.
"""
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async with self._share_data_lock:
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
@@ -694,10 +696,11 @@ class DAGContext:
overwrite (bool): Whether to overwrite the share data if the key
already exists. Defaults to False.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async with self._share_data_lock:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key.

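The race this guards against: between the `key in self._share_data` check and the assignment, another concurrently running task of the same DAG could write the same key. Below is a minimal, self-contained sketch of the now-locked pattern; the class is an illustrative stand-in, not the real DAGContext, and the getter name mirrors the hunk above.

import asyncio
from typing import Any, Dict


class ShareDataSketch:
    """Illustrative stand-in for DAGContext's share-data methods."""

    def __init__(self) -> None:
        self._share_data: Dict[str, Any] = {}
        self._share_data_lock = asyncio.Lock()

    async def save_to_share_data(
        self, key: str, data: Any, overwrite: bool = False
    ) -> None:
        async with self._share_data_lock:
            # The existence check and the write are now one atomic step.
            if key in self._share_data and not overwrite:
                raise ValueError(f"Share data key {key} already exists")
            self._share_data[key] = data

    async def get_from_share_data(self, key: str) -> Any:
        async with self._share_data_lock:
            return self._share_data.get(key)


async def _demo() -> None:
    ctx = ShareDataSketch()
    await ctx.save_to_share_data("k", 1)
    # A concurrent duplicate write now fails deterministically.
    results = await asyncio.gather(
        ctx.save_to_share_data("k", 2), return_exceptions=True
    )
    print(results)  # [ValueError('Share data key k already exists')]
    print(await ctx.get_from_share_data("k"))  # 1


asyncio.run(_demo())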
View File

@@ -687,6 +687,10 @@ class IOField(Resource):
" True",
examples=[0, 1, 2],
)
mappers: Optional[List[str]] = Field(
default=None,
description="The mappers of the field, used to transform the field to the target type",
)
@classmethod
def build_from(
@@ -698,10 +702,16 @@ class IOField(Resource):
is_list: bool = False,
dynamic: bool = False,
dynamic_minimum: int = 0,
mappers: Optional[Union[Type, List[Type]]] = None,
):
"""Build the resource from the type."""
type_name = type.__qualname__
type_cls = _get_type_name(type)
# TODO: Check that the mapper instance can be created without required
# parameters.
if mappers and not isinstance(mappers, list):
mappers = [mappers]
mappers_cls = [_get_type_name(m) for m in mappers] if mappers else None
return cls(
label=label,
name=name,
@@ -711,6 +721,7 @@ class IOField(Resource):
description=description or label,
dynamic=dynamic,
dynamic_minimum=dynamic_minimum,
mappers=mappers_cls,
)
@model_validator(mode="before")

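A hedged sketch of the new `mappers` argument in use. `ToIntMapper` and the import paths are assumptions for illustration; per the code above, `build_from` accepts a single type or a list, stores only the qualified type names, and leaves instantiation to the flow factory.

from dbgpt.core.awel import MapOperator  # assumed import path
from dbgpt.core.awel.flow import IOField  # assumed import path


class ToIntMapper(MapOperator[str, int]):
    """Hypothetical mapper: parse the upstream string output into an int."""

    async def map(self, value: str) -> int:
        return int(value)


field = IOField.build_from(
    "Parsed Number",
    "parsed_number",
    str,
    description="The output string, parsed to an int by the mapper",
    mappers=ToIntMapper,  # a bare type is normalized to [ToIntMapper]
)
print(field.mappers)  # e.g. ['ToIntMapper'] (qualified type name)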
View File

@@ -1,10 +1,11 @@
"""Build AWEL DAGs from serialized data."""
import dataclasses
import logging
import uuid
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from typing_extensions import Annotated
@@ -565,6 +566,17 @@ class FlowPanel(BaseModel):
return [FlowVariables(**v) for v in variables]
@dataclasses.dataclass
class _KeyToNodeItem:
"""Key to node item."""
key: str  # Key of the node at the other end of the edge
source_order: int  # Order of the output handle on the source node
target_order: int  # Order of the input handle on the target node
mappers: List[str]  # Type names of the mapper operators applied along this edge
edge_index: int  # Index of the edge in flow_data.edges
class FlowFactory:
"""Flow factory."""
@@ -580,8 +592,10 @@ class FlowFactory:
key_to_operator_nodes: Dict[str, FlowNodeData] = {}
key_to_resource_nodes: Dict[str, FlowNodeData] = {}
key_to_resource: Dict[str, ResourceMetadata] = {}
key_to_downstream: Dict[str, List[Tuple[str, int, int]]] = {}
key_to_upstream: Dict[str, List[Tuple[str, int, int]]] = {}
# Map each node key to its downstream edges
key_to_downstream: Dict[str, List[_KeyToNodeItem]] = {}
# Map each node key to its upstream edges
key_to_upstream: Dict[str, List[_KeyToNodeItem]] = {}
key_to_upstream_node: Dict[str, List[FlowNodeData]] = {}
for node in flow_data.nodes:
key = node.id
@@ -600,7 +614,7 @@ class FlowFactory:
"No operator or resource nodes found in the flow."
)
for edge in flow_data.edges:
for edge_index, edge in enumerate(flow_data.edges):
source_key = edge.source
target_key = edge.target
source_node: FlowNodeData | None = key_to_operator_nodes.get(
@@ -620,12 +634,37 @@ class FlowFactory:
if source_node.data.is_operator and target_node.data.is_operator:
# Operator to operator.
mappers = []
for i, out in enumerate(source_node.data.outputs):
if i != edge.source_order:
continue
if out.mappers:
# The current edge is a mapper edge; collect its mappers.
mappers = out.mappers
# Note: Mappers on the inputs of the target node are not supported yet.
downstream = key_to_downstream.get(source_key, [])
downstream.append((target_key, edge.source_order, edge.target_order))
downstream.append(
_KeyToNodeItem(
key=target_key,
source_order=edge.source_order,
target_order=edge.target_order,
mappers=mappers,
edge_index=edge_index,
)
)
key_to_downstream[source_key] = downstream
upstream = key_to_upstream.get(target_key, [])
upstream.append((source_key, edge.source_order, edge.target_order))
upstream.append(
_KeyToNodeItem(
key=source_key,
source_order=edge.source_order,
target_order=edge.target_order,
mappers=mappers,
edge_index=edge_index,
)
)
key_to_upstream[target_key] = upstream
elif not source_node.data.is_operator and target_node.data.is_operator:
# Resource to operator.
@@ -683,10 +722,10 @@ class FlowFactory:
# Sort the keys by the order of the nodes.
for key, value in key_to_downstream.items():
# Sort by source_order.
key_to_downstream[key] = sorted(value, key=lambda x: x[1])
key_to_downstream[key] = sorted(value, key=lambda x: x.source_order)
for key, value in key_to_upstream.items():
# Sort by target_order.
key_to_upstream[key] = sorted(value, key=lambda x: x[2])
key_to_upstream[key] = sorted(value, key=lambda x: x.target_order)
sorted_key_to_resource_nodes = list(key_to_resource_nodes.values())
sorted_key_to_resource_nodes = sorted(
@@ -784,8 +823,8 @@ class FlowFactory:
self,
flow_panel: FlowPanel,
key_to_tasks: Dict[str, DAGNode],
key_to_downstream: Dict[str, List[Tuple[str, int, int]]],
key_to_upstream: Dict[str, List[Tuple[str, int, int]]],
key_to_downstream: Dict[str, List[_KeyToNodeItem]],
key_to_upstream: Dict[str, List[_KeyToNodeItem]],
dag_id: Optional[str] = None,
) -> DAG:
"""Build the DAG."""
@@ -832,7 +871,8 @@ class FlowFactory:
# The upstream list has already been sorted by target_order, so we just
# need to connect the task to its upstream tasks in order.
for upstream_key, _, _ in upstream:
for up_item in upstream:
upstream_key = up_item.key
# Just one direction.
upstream_task = key_to_tasks.get(upstream_key)
if not upstream_task:
@@ -843,7 +883,13 @@ class FlowFactory:
upstream_task.set_node_id(dag._new_node_id())
if upstream_task is None:
raise ValueError("Unable to find upstream task.")
upstream_task >> task
tasks = _build_mapper_operators(dag, up_item.mappers)
tasks.append(task)
last_task = upstream_task
for t in tasks:
# Chain the tasks: upstream task -> mappers -> current task.
last_task >> t
last_task = t
return dag
def pre_load_requirements(self, flow_panel: FlowPanel):
@@ -950,6 +996,23 @@ def _topological_sort(
return key_to_order
def _build_mapper_operators(dag: DAG, mappers: List[str]) -> List[DAGNode]:
"""Instantiate mapper operators from their type names within the given DAG."""
from .base import _get_type_cls
tasks = []
for mapper in mappers:
try:
mapper_cls = _get_type_cls(mapper)
task = mapper_cls()
if not task._node_id:
task.set_node_id(dag._new_node_id())
tasks.append(task)
except Exception as e:
err_msg = f"Unable to build mapper task: {mapper}, error: {e}"
raise FlowMetadataException(err_msg) from e
return tasks
def fill_flow_panel(flow_panel: FlowPanel):
"""Fill the flow panel with the latest metadata.
@@ -978,6 +1041,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
i.dynamic = new_param.dynamic
i.is_list = new_param.is_list
i.dynamic_minimum = new_param.dynamic_minimum
i.mappers = new_param.mappers
for i in node.data.outputs:
if i.name in output_parameters:
new_param = output_parameters[i.name]
@@ -986,6 +1050,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
i.dynamic = new_param.dynamic
i.is_list = new_param.is_list
i.dynamic_minimum = new_param.dynamic_minimum
i.mappers = new_param.mappers
else:
data = cast(ResourceMetadata, node.data)
key = data.get_origin_id()

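The net effect on DAG assembly: an edge whose source output declares mappers no longer becomes a direct `upstream_task >> task` connection; `_build_mapper_operators` instantiates the mappers and the loop above threads them in between. A self-contained sketch of just that chaining logic, independent of the real DAG classes:

from typing import List, Tuple


def chain_through_mappers(
    upstream: str, mappers: List[str], target: str
) -> List[Tuple[str, str]]:
    """Mirror of the build_dag loop: connect upstream -> mappers -> target."""
    edges = []
    last = upstream
    for node in [*mappers, target]:
        edges.append((last, node))  # corresponds to `last_task >> t`
        last = node
    return edges


# Without mappers the edge stays direct; with one mapper it becomes two hops.
print(chain_through_mappers("trigger", [], "llm_task"))
# [('trigger', 'llm_task')]
print(chain_through_mappers("trigger", ["MessagesOutputMapper"], "llm_task"))
# [('trigger', 'MessagesOutputMapper'), ('MessagesOutputMapper', 'llm_task')]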
View File

@@ -945,6 +945,16 @@ class StringHttpTrigger(HttpTrigger):
class CommonLLMHttpTrigger(HttpTrigger):
"""Common LLM http trigger for AWEL."""
class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]):
"""Messages output mapper."""
async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
"""Map the request body to messages."""
if isinstance(request_body.messages, str):
return request_body.messages
else:
raise ValueError("The 'messages' field to be transformed is not a string")
metadata = ViewMetadata(
label=_("Common LLM Http Trigger"),
name="common_llm_http_trigger",
@@ -965,6 +975,16 @@ class CommonLLMHttpTrigger(HttpTrigger):
"LLM http body"
),
),
IOField.build_from(
_("Request String Messages"),
"request_string_messages",
str,
description=_(
"The request string messages of the API endpoint, parsed from "
"the 'messages' field of the request body"
),
mappers=[MessagesOutputMapper],
),
],
parameters=[
Parameter.build_from(
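Taken together with the factory changes, an operator wired to the new "request_string_messages" output receives the plain string instead of the whole request body. A hedged sketch of the mapper's behavior in isolation; calling `map` directly bypasses the DAG runtime, the operator is assumed to be constructible standalone, and the request body is a duck-typed stand-in rather than a real CommonLLMHttpRequestBody:

import asyncio
from types import SimpleNamespace


async def _demo() -> None:
    mapper = CommonLLMHttpTrigger.MessagesOutputMapper()
    body = SimpleNamespace(messages="Hello, AWEL!")  # stand-in request body
    print(await mapper.map(body))  # Hello, AWEL!
    # A non-string `messages` field (e.g. a chat message list) raises.
    try:
        await mapper.map(SimpleNamespace(messages=[{"role": "user"}]))
    except ValueError as e:
        print(e)


asyncio.run(_demo())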