DB-GPT (mirror of https://github.com/csunny/DB-GPT.git)
Commit: chore: Merge latest code
@@ -619,6 +619,7 @@ class DAGContext:
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids
         self._event_loop_task_id = event_loop_task_id
         self._dag_variables = dag_variables
+        self._share_data_lock = asyncio.Lock()
 
     @property
     def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -680,8 +681,9 @@ class DAGContext:
         Returns:
             Any: The share data, you can cast it to the real type
         """
-        logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
-        return self._share_data.get(key)
+        async with self._share_data_lock:
+            logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
+            return self._share_data.get(key)
 
     async def save_to_share_data(
         self, key: str, data: Any, overwrite: bool = False
@@ -694,10 +696,11 @@ class DAGContext:
             overwrite (bool): Whether overwrite the share data if the key
                 already exists. Defaults to None.
         """
-        if key in self._share_data and not overwrite:
-            raise ValueError(f"Share data key {key} already exists")
-        logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
-        self._share_data[key] = data
+        async with self._share_data_lock:
+            if key in self._share_data and not overwrite:
+                raise ValueError(f"Share data key {key} already exists")
+            logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
+            self._share_data[key] = data
 
     async def get_task_share_data(self, task_name: str, key: str) -> Any:
         """Get share data by task name and key.
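The point of the new lock: get_from_share_data and save_to_share_data previously touched self._share_data with no synchronization, so two coroutines in the same DAG run could interleave the key-exists check and the write. A minimal standalone sketch of the same pattern (illustrative names, not the DB-GPT API):

import asyncio
from typing import Any, Dict


class _ShareData:
    """Illustrative stand-in for DAGContext's share-data store."""

    def __init__(self) -> None:
        self._share_data: Dict[str, Any] = {}
        # One lock per context: concurrent tasks in the same DAG run
        # must not interleave check-then-set operations.
        self._share_data_lock = asyncio.Lock()

    async def save(self, key: str, data: Any, overwrite: bool = False) -> None:
        async with self._share_data_lock:
            # The membership check and the assignment are atomic with
            # respect to other coroutines acquiring this lock.
            if key in self._share_data and not overwrite:
                raise ValueError(f"Share data key {key} already exists")
            self._share_data[key] = data

    async def get(self, key: str) -> Any:
        async with self._share_data_lock:
            return self._share_data.get(key)


async def main() -> None:
    sd = _ShareData()
    # Two writers racing on the same key: exactly one succeeds.
    results = await asyncio.gather(
        sd.save("k", 1), sd.save("k", 2), return_exceptions=True
    )
    print(results, await sd.get("k"))


asyncio.run(main())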
@@ -687,6 +687,10 @@ class IOField(Resource):
             " True",
             examples=[0, 1, 2],
         )
+    mappers: Optional[List[str]] = Field(
+        default=None,
+        description="The mappers of the field, transform the field to the target type",
+    )
 
     @classmethod
     def build_from(
@@ -698,10 +702,16 @@ class IOField(Resource):
         is_list: bool = False,
         dynamic: bool = False,
         dynamic_minimum: int = 0,
+        mappers: Optional[Union[Type, List[Type]]] = None,
     ):
         """Build the resource from the type."""
         type_name = type.__qualname__
         type_cls = _get_type_name(type)
+        # TODO: Check the mapper instance can be created without required
+        # parameters.
+        if mappers and not isinstance(mappers, list):
+            mappers = [mappers]
+        mappers_cls = [_get_type_name(m) for m in mappers] if mappers else None
         return cls(
             label=label,
             name=name,
@@ -711,6 +721,7 @@ class IOField(Resource):
             description=description or label,
             dynamic=dynamic,
             dynamic_minimum=dynamic_minimum,
+            mappers=mappers_cls,
         )
 
     @model_validator(mode="before")
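Taken together, the three IOField hunks let build_from accept a mapper operator class (or a list of them), normalize a single class to a list, and persist the classes as type-name strings on the field metadata. A hedged usage sketch; the import paths and the UpperCaseMapper class are assumptions for illustration, not part of this commit:

# Sketch only: import paths and UpperCaseMapper are illustrative assumptions.
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import IOField


class UpperCaseMapper(MapOperator[str, str]):
    """Toy mapper that upper-cases the upstream string output."""

    async def map(self, value: str) -> str:
        return value.upper()


# A single class would be normalized to a one-element list internally,
# then serialized to qualified type-name strings on the field metadata.
field = IOField.build_from(
    "Shout",  # label
    "shout",  # name
    str,      # type of the field
    description="Upper-cased output",
    mappers=[UpperCaseMapper],
)
print(field.mappers)  # e.g. a list with UpperCaseMapper's qualified type name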
@@ -1,10 +1,11 @@
 """Build AWEL DAGs from serialized data."""
 
+import dataclasses
 import logging
 import uuid
 from contextlib import suppress
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
 
 from typing_extensions import Annotated
@@ -565,6 +566,17 @@ class FlowPanel(BaseModel):
         return [FlowVariables(**v) for v in variables]
 
 
+@dataclasses.dataclass
+class _KeyToNodeItem:
+    """Key to node item."""
+
+    key: str
+    source_order: int
+    target_order: int
+    mappers: List[str]
+    edge_index: int
+
+
 class FlowFactory:
     """Flow factory."""
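_KeyToNodeItem replaces the anonymous (key, source_order, target_order) tuples used below, so edge metadata can grow (here: mappers and edge_index) without positional indexing. A small self-contained sketch of the sorting change this enables (see the -683,10 hunk further down):

import dataclasses
from typing import List


@dataclasses.dataclass
class _KeyToNodeItem:
    """Edge endpoint plus ordering and mapper metadata, as in this commit."""

    key: str
    source_order: int
    target_order: int
    mappers: List[str]
    edge_index: int


edges = [
    _KeyToNodeItem("b", source_order=1, target_order=0, mappers=[], edge_index=0),
    _KeyToNodeItem("c", source_order=0, target_order=1, mappers=[], edge_index=1),
]
# Attribute access replaces the old tuple indexing (x[1] / x[2]).
downstream = sorted(edges, key=lambda x: x.source_order)
upstream = sorted(edges, key=lambda x: x.target_order)
print([e.key for e in downstream], [e.key for e in upstream])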
@@ -580,8 +592,10 @@ class FlowFactory:
         key_to_operator_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource: Dict[str, ResourceMetadata] = {}
-        key_to_downstream: Dict[str, List[Tuple[str, int, int]]] = {}
-        key_to_upstream: Dict[str, List[Tuple[str, int, int]]] = {}
+        # Record current node's downstream
+        key_to_downstream: Dict[str, List[_KeyToNodeItem]] = {}
+        # Record current node's upstream
+        key_to_upstream: Dict[str, List[_KeyToNodeItem]] = {}
         key_to_upstream_node: Dict[str, List[FlowNodeData]] = {}
         for node in flow_data.nodes:
             key = node.id
@@ -600,7 +614,7 @@ class FlowFactory:
                 "No operator or resource nodes found in the flow."
             )
 
-        for edge in flow_data.edges:
+        for edge_index, edge in enumerate(flow_data.edges):
            source_key = edge.source
            target_key = edge.target
            source_node: FlowNodeData | None = key_to_operator_nodes.get(
@@ -620,12 +634,37 @@ class FlowFactory:
 
             if source_node.data.is_operator and target_node.data.is_operator:
                 # Operator to operator.
+                mappers = []
+                for i, out in enumerate(source_node.data.outputs):
+                    if i != edge.source_order:
+                        continue
+                    if out.mappers:
+                        # Current edge is a mapper edge, find the mappers.
+                        mappers = out.mappers
+                    # Note: Not support mappers in the inputs of the target node now.
+
                 downstream = key_to_downstream.get(source_key, [])
-                downstream.append((target_key, edge.source_order, edge.target_order))
+                downstream.append(
+                    _KeyToNodeItem(
+                        key=target_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_downstream[source_key] = downstream
 
                 upstream = key_to_upstream.get(target_key, [])
-                upstream.append((source_key, edge.source_order, edge.target_order))
+                upstream.append(
+                    _KeyToNodeItem(
+                        key=source_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_upstream[target_key] = upstream
             elif not source_node.data.is_operator and target_node.data.is_operator:
                 # Resource to operator.
@@ -683,10 +722,10 @@ class FlowFactory:
         # Sort the keys by the order of the nodes.
         for key, value in key_to_downstream.items():
             # Sort by source_order.
-            key_to_downstream[key] = sorted(value, key=lambda x: x[1])
+            key_to_downstream[key] = sorted(value, key=lambda x: x.source_order)
         for key, value in key_to_upstream.items():
             # Sort by target_order.
-            key_to_upstream[key] = sorted(value, key=lambda x: x[2])
+            key_to_upstream[key] = sorted(value, key=lambda x: x.target_order)
 
         sorted_key_to_resource_nodes = list(key_to_resource_nodes.values())
         sorted_key_to_resource_nodes = sorted(
@@ -784,8 +823,8 @@ class FlowFactory:
         self,
         flow_panel: FlowPanel,
         key_to_tasks: Dict[str, DAGNode],
-        key_to_downstream: Dict[str, List[Tuple[str, int, int]]],
-        key_to_upstream: Dict[str, List[Tuple[str, int, int]]],
+        key_to_downstream: Dict[str, List[_KeyToNodeItem]],
+        key_to_upstream: Dict[str, List[_KeyToNodeItem]],
         dag_id: Optional[str] = None,
     ) -> DAG:
         """Build the DAG."""
@@ -832,7 +871,8 @@ class FlowFactory:
 
             # This upstream has been sorted according to the order in the downstream
             # So we just need to connect the task to the upstream.
-            for upstream_key, _, _ in upstream:
+            for up_item in upstream:
+                upstream_key = up_item.key
                 # Just one direction.
                 upstream_task = key_to_tasks.get(upstream_key)
                 if not upstream_task:
@@ -843,7 +883,13 @@ class FlowFactory:
                     upstream_task.set_node_id(dag._new_node_id())
                 if upstream_task is None:
                     raise ValueError("Unable to find upstream task.")
-                upstream_task >> task
+                tasks = _build_mapper_operators(dag, up_item.mappers)
+                tasks.append(task)
+                last_task = upstream_task
+                for t in tasks:
+                    # Connect the task to the upstream task.
+                    last_task >> t
+                    last_task = t
         return dag
 
     def pre_load_requirements(self, flow_panel: FlowPanel):
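Net effect of the two hunks above: instead of wiring upstream_task >> task directly, any mapper operators recorded on the connecting edge are instantiated (by _build_mapper_operators, added below) and spliced in between, giving upstream >> mapper_1 >> ... >> mapper_n >> task. A toy sketch of the splicing loop, with a connect() function standing in for AWEL's >> operator:

from typing import List


# Toy stand-ins: nodes are dicts and connect() records edges, in place
# of AWEL's DAGNode and its `>>` operator.
def connect(up: dict, down: dict) -> None:
    up.setdefault("downstream", []).append(down["name"])


def splice(upstream_task: dict, mapper_tasks: List[dict], task: dict) -> None:
    # Mirror of the commit's loop: mappers first, then the real task.
    tasks = list(mapper_tasks)
    tasks.append(task)
    last_task = upstream_task
    for t in tasks:
        connect(last_task, t)
        last_task = t


up, m1, m2, t = ({"name": n} for n in ("up", "m1", "m2", "t"))
splice(up, [m1, m2], t)
print(up["downstream"], m1["downstream"], m2["downstream"])
# ['m1'] ['m2'] ['t']  ->  up >> m1 >> m2 >> t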
@@ -950,6 +996,23 @@ def _topological_sort(
     return key_to_order
 
 
+def _build_mapper_operators(dag: DAG, mappers: List[str]) -> List[DAGNode]:
+    from .base import _get_type_cls
+
+    tasks = []
+    for mapper in mappers:
+        try:
+            mapper_cls = _get_type_cls(mapper)
+            task = mapper_cls()
+            if not task._node_id:
+                task.set_node_id(dag._new_node_id())
+            tasks.append(task)
+        except Exception as e:
+            err_msg = f"Unable to build mapper task: {mapper}, error: {e}"
+            raise FlowMetadataException(err_msg)
+    return tasks
+
+
 def fill_flow_panel(flow_panel: FlowPanel):
     """Fill the flow panel with the latest metadata.
||||
@@ -978,6 +1041,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
|
||||
i.dynamic = new_param.dynamic
|
||||
i.is_list = new_param.is_list
|
||||
i.dynamic_minimum = new_param.dynamic_minimum
|
||||
i.mappers = new_param.mappers
|
||||
for i in node.data.outputs:
|
||||
if i.name in output_parameters:
|
||||
new_param = output_parameters[i.name]
|
||||
@@ -986,6 +1050,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     i.dynamic = new_param.dynamic
                     i.is_list = new_param.is_list
                     i.dynamic_minimum = new_param.dynamic_minimum
+                    i.mappers = new_param.mappers
             else:
                 data = cast(ResourceMetadata, node.data)
                 key = data.get_origin_id()
@@ -945,6 +945,16 @@ class StringHttpTrigger(HttpTrigger):
 class CommonLLMHttpTrigger(HttpTrigger):
     """Common LLM http trigger for AWEL."""
 
+    class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]):
+        """Messages output mapper."""
+
+        async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
+            """Map the request body to messages."""
+            if isinstance(request_body.messages, str):
+                return request_body.messages
+            else:
+                raise ValueError("Messages to be transformed is not a string")
+
     metadata = ViewMetadata(
         label=_("Common LLM Http Trigger"),
         name="common_llm_http_trigger",
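The nested mapper exposes the messages field of the request body as a plain-string output and rejects anything else, for example a structured chat-message list. A rough behavioral sketch, with a simplified dataclass standing in for CommonLLMHttpRequestBody:

import asyncio
from dataclasses import dataclass
from typing import List, Union


@dataclass
class _FakeRequestBody:
    """Simplified stand-in for CommonLLMHttpRequestBody."""

    messages: Union[str, List[dict]]


async def map_messages(request_body: _FakeRequestBody) -> str:
    # Same check as MessagesOutputMapper.map: only plain strings pass through.
    if isinstance(request_body.messages, str):
        return request_body.messages
    raise ValueError("Messages to be transformed is not a string")


print(asyncio.run(map_messages(_FakeRequestBody(messages="hello"))))
# Passing a list of chat messages raises ValueError instead of flattening it.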
@@ -965,6 +975,16 @@ class CommonLLMHttpTrigger(HttpTrigger):
                     "LLM http body"
                 ),
             ),
+            IOField.build_from(
+                _("Request String Messages"),
+                "request_string_messages",
+                str,
+                description=_(
+                    "The request string messages of the API endpoint, parsed from "
+                    "'messages' field of the request body"
+                ),
+                mappers=[MessagesOutputMapper],
+            ),
         ],
         parameters=[
             Parameter.build_from(
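With this second output, one trigger request can feed two branches: the full parsed body on the first output, and only the string messages, routed through MessagesOutputMapper, on the new one. An illustrative payload sketch only; the exact schema is defined by CommonLLMHttpRequestBody:

import json

# Illustrative payload; field names follow the commit, not a full schema.
payload = {
    "model": "example-model",
    "messages": "What is AWEL?",  # a plain string, so the mapper passes it through
    "stream": False,
}
# First output: the whole parsed request body.
# Second output ("Request String Messages"): just payload["messages"],
# produced by the MessagesOutputMapper attached via mappers=[...].
print(json.dumps(payload))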