feat(core): Support dag scope variables

This commit is contained in:
Fangyin Cheng
2024-08-22 11:20:59 +08:00
parent 1c7a6c9122
commit de702ef8e6
12 changed files with 526 additions and 57 deletions

View File

@@ -347,9 +347,10 @@ CREATE TABLE `dbgpt_serve_variables` (
`category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
`encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
`salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, etc)',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
`enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
`description` text DEFAULT NULL COMMENT 'Variable description',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',

View File

@@ -101,9 +101,10 @@ CREATE TABLE `dbgpt_serve_variables` (
`category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
`encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
`salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, etc)',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
`enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
`description` text DEFAULT NULL COMMENT 'Variable description',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',

View File

@@ -20,3 +20,11 @@ del load_dotenv
TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE = "knowledge_factory_domain_type"
TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE = "knowledge_chat_domain_type"
DOMAIN_TYPE_FINANCIAL_REPORT = "FinancialReport"
# Names of the supported variable scopes. "global" is the default scope used by
# the variables models when no scope is given.
VARIABLES_SCOPE_GLOBAL = "global"
VARIABLES_SCOPE_APP = "app"
VARIABLES_SCOPE_AGENT = "agent"
VARIABLES_SCOPE_FLOW = "flow"
VARIABLES_SCOPE_DATASOURCE = "datasource"
# Private scopes. For "flow_priv" the variable's scope_key is the DAG id of the
# flow that owns the variable.
VARIABLES_SCOPE_FLOW_PRIVATE = "flow_priv"
# NOTE(review): presumably the agent counterpart of "flow_priv"; the meaning of
# its scope_key is not shown here - confirm before relying on it.
VARIABLES_SCOPE_AGENT_PRIVATE = "agent_priv"

View File

@@ -5,6 +5,7 @@ DAG is the core component of AWEL, it is used to define the relationship between
import asyncio
import contextvars
import dataclasses
import logging
import threading
import uuid
@@ -17,6 +18,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
@@ -489,6 +491,100 @@ def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
@dataclasses.dataclass
class _DAGVariablesItem:
    """The DAG variables item.

    It is a private class, just used for internal.

    Each instance describes one variable carried by ``DAGVariables``. The fields
    are copied one-to-one into a ``StorageVariables`` object when the DAG
    variables are converted to a provider.
    """

    # Required fields: the variable key, its name, a display label and the value.
    key: str
    name: str
    label: str
    value: Any
    # Variable category, restricted to "common" or "secret".
    category: Literal["common", "secret"] = "common"
    # Variable scope name; defaults to the global scope.
    scope: str = "global"
    value_type: Optional[str] = None
    # Optional qualifier of the scope (unset by default).
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None
    description: Optional[str] = None
@dataclasses.dataclass
class DAGVariables:
    """The DAG variables.

    Holds the variable items attached to a DAG (or to one DAG run) together with
    a lazily built variables provider that serves those items.
    """

    items: List["_DAGVariablesItem"] = dataclasses.field(default_factory=list)
    # Provider built from ``items`` on the first ``to_provider`` call.
    _cached_provider: Optional["VariablesProvider"] = None
    # Guards the construction of ``_cached_provider``.
    _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)

    def merge(self, dag_variables: "DAGVariables") -> "DAGVariables":
        """Merge the DAG variables.

        Items of the current object take precedence: an item of ``dag_variables``
        is only added when no item of the current object has the same identity
        (key, name, scope, scope_key, sys_code and user_name).

        Args:
            dag_variables (DAGVariables): The DAG variables to merge

        Returns:
            DAGVariables: A new object, neither input is modified.
        """

        def _identity(item: "_DAGVariablesItem") -> tuple:
            # A tuple keeps every component in its own slot. Joining the parts
            # with "_" (and skipping empty ones) made different identities
            # collide, e.g. scope_key="x" versus sys_code="x".
            return (
                item.key,
                item.name,
                item.scope,
                item.scope_key or None,
                item.sys_code or None,
                item.user_name or None,
            )

        exist_vars = {_identity(item) for item in self.items}
        added = [
            item for item in dag_variables.items if _identity(item) not in exist_vars
        ]

        # A cached provider may only be reused when it was built from exactly
        # the items of the merged result, otherwise it would miss variables.
        if not added:
            cached_provider = self._cached_provider
        elif not self.items:
            cached_provider = dag_variables._cached_provider
        else:
            cached_provider = None

        return DAGVariables(
            items=[*self.items, *added],
            _cached_provider=cached_provider,
        )

    def to_provider(self) -> "VariablesProvider":
        """Convert the DAG variables to variables provider.

        The provider is created once and cached; later calls return the cached
        instance.

        Returns:
            VariablesProvider: The variables provider
        """
        if self._cached_provider:
            return self._cached_provider

        from ...interface.variables import (
            StorageVariables,
            StorageVariablesProvider,
        )

        with self._lock:
            # Check again under the lock: another thread may have built the
            # provider while this one was waiting.
            if not self._cached_provider:
                provider = StorageVariablesProvider()
                for item in self.items:
                    storage_vars = StorageVariables(
                        key=item.key,
                        name=item.name,
                        label=item.label,
                        value=item.value,
                        category=item.category,
                        scope=item.scope,
                        value_type=item.value_type,
                        scope_key=item.scope_key,
                        sys_code=item.sys_code,
                        user_name=item.user_name,
                        description=item.description,
                    )
                    provider.save(storage_vars)
                self._cached_provider = provider
        return self._cached_provider
class DAGContext:
"""The context of current DAG, created when the DAG is running.
@@ -502,6 +598,7 @@ class DAGContext:
event_loop_task_id: int,
streaming_call: bool = False,
node_name_to_ids: Optional[Dict[str, str]] = None,
dag_variables: Optional[DAGVariables] = None,
) -> None:
"""Initialize a DAGContext.
@@ -511,6 +608,7 @@ class DAGContext:
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
dag_variables (Optional[DAGVariables], optional): The DAG variables.
"""
if not node_name_to_ids:
node_name_to_ids = {}
@@ -520,6 +618,7 @@ class DAGContext:
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
self._event_loop_task_id = event_loop_task_id
self._dag_variables = dag_variables
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -653,6 +752,7 @@ class DAG:
resource_group: Optional[ResourceGroup] = None,
tags: Optional[Dict[str, str]] = None,
description: Optional[str] = None,
default_dag_variables: Optional[DAGVariables] = None,
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
@@ -666,6 +766,7 @@ class DAG:
self._resource_group: Optional[ResourceGroup] = resource_group
self._lock = asyncio.Lock()
self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {}
self._default_dag_variables = default_dag_variables
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:

View File

@@ -17,6 +17,7 @@ from dbgpt._private.pydantic import (
model_to_dict,
model_validator,
)
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.core.awel.dag.base import DAG, DAGNode
from dbgpt.core.awel.dag.dag_manager import DAGMetadata
@@ -166,29 +167,23 @@ class FlowData(BaseModel):
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
class VariablesRequest(BaseModel):
"""Variable request model.
For creating a new variable in the DB-GPT.
"""
class _VariablesRequestBase(BaseModel):
key: str = Field(
...,
description="The key of the variable to create",
examples=["dbgpt.model.openai.api_key"],
)
name: str = Field(
...,
description="The name of the variable to create",
examples=["my_first_openai_key"],
)
label: str = Field(
...,
description="The label of the variable to create",
examples=["My First OpenAI Key"],
)
value: Any = Field(
..., description="The value of the variable to create", examples=["1234567890"]
description: Optional[str] = Field(
None,
description="The description of the variable to create",
examples=["Your OpenAI API key"],
)
value_type: Literal["str", "int", "float", "bool"] = Field(
"str",
@@ -206,10 +201,26 @@ class VariablesRequest(BaseModel):
examples=["global"],
)
scope_key: Optional[str] = Field(
...,
None,
description="The scope key of the variable to create",
examples=["dbgpt"],
)
class VariablesRequest(_VariablesRequestBase):
"""Variable request model.
For creating a new variable in the DB-GPT.
"""
name: str = Field(
...,
description="The name of the variable to create",
examples=["my_first_openai_key"],
)
value: Any = Field(
..., description="The value of the variable to create", examples=["1234567890"]
)
enabled: Optional[bool] = Field(
True,
description="Whether the variable is enabled",
@@ -219,6 +230,80 @@ class VariablesRequest(BaseModel):
sys_code: Optional[str] = Field(None, description="System code")
class ParsedFlowVariables(BaseModel):
    """Parsed variables for the flow.

    Holds the components of a variable reference: ``key`` and ``scope`` are
    required, every other component is optional and defaults to None.
    """

    key: str = Field(
        ...,
        description="The key of the variable",
        examples=["dbgpt.model.openai.api_key"],
    )
    name: Optional[str] = Field(
        None,
        description="The name of the variable",
        examples=["my_first_openai_key"],
    )
    scope: str = Field(
        ...,
        description="The scope of the variable",
        examples=["global"],
    )
    scope_key: Optional[str] = Field(
        None,
        description="The scope key of the variable",
        examples=["dbgpt"],
    )
    sys_code: Optional[str] = Field(None, description="System code")
    user_name: Optional[str] = Field(None, description="User name")
class FlowVariables(_VariablesRequestBase):
    """Variables for the flow.

    Both ``name`` and ``value`` are optional. When ``value`` is a string in the
    variable reference format, its parsed form is stored in ``parsed_variables``.
    """

    name: Optional[str] = Field(
        None,
        description="The name of the variable",
        examples=["my_first_openai_key"],
    )
    value: Optional[Any] = Field(
        None, description="The value of the variable", examples=["1234567890"]
    )
    parsed_variables: Optional[ParsedFlowVariables] = Field(
        None, description="The parsed variables, parsed from the value"
    )

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill ``parsed_variables`` from ``value`` before validation.

        Non-dict input, and input that already carries ``parsed_variables``, is
        returned untouched.
        """
        if isinstance(values, dict) and "parsed_variables" not in values:
            parsed = cls.parse_value_to_variables(values.get("value"))
            if parsed:
                values["parsed_variables"] = parsed
        return values

    @classmethod
    def parse_value_to_variables(cls, value: Any) -> Optional[ParsedFlowVariables]:
        """Parse the value to variables.

        Args:
            value (Any): The value to parse

        Returns:
            Optional[ParsedFlowVariables]: The parsed variables, None if the value is
                empty, not a string, or not in the variable format
        """
        from ...interface.variables import _is_variable_format, parse_variable

        # Only non-empty strings are candidates for the format check.
        is_candidate = bool(value) and isinstance(value, str)
        if not (is_candidate and _is_variable_format(value)):
            return None
        return ParsedFlowVariables(**parse_variable(value))
class State(str, Enum):
"""State of a flow panel."""
@@ -409,7 +494,7 @@ class FlowPanel(BaseModel):
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
default=None, description="The metadata of the flow"
)
variables: Optional[List[VariablesRequest]] = Field(
variables: Optional[List[FlowVariables]] = Field(
default=None, description="The variables of the flow"
)
authors: Optional[List[str]] = Field(
@@ -437,6 +522,21 @@ class FlowPanel(BaseModel):
"""Convert to dict."""
return model_to_dict(self, exclude={"flow_dag"})
def get_variables_dict(self) -> List[Dict[str, Any]]:
    """Get the variables as a list of plain dicts.

    Returns:
        List[Dict[str, Any]]: One dict per flow variable, an empty list when the
            flow has no variables.
    """
    if not self.variables:
        return []
    # Use the project's pydantic compatibility helper (as ``to_dict`` does)
    # instead of ``BaseModel.dict()``, which is deprecated in pydantic v2.
    return [model_to_dict(variable) for variable in self.variables]
@classmethod
def parse_variables(
    cls, variables: Optional[List[Dict[str, Any]]] = None
) -> Optional[List[FlowVariables]]:
    """Build ``FlowVariables`` objects from raw dicts.

    Returns None when ``variables`` is None or empty.
    """
    if variables:
        return [FlowVariables(**raw) for raw in variables]
    return None
class FlowFactory:
"""Flow factory."""
@@ -657,10 +757,36 @@ class FlowFactory:
dag_id: Optional[str] = None,
) -> DAG:
"""Build the DAG."""
from ..dag.base import DAGVariables, _DAGVariablesItem
formatted_name = flow_panel.name.replace(" ", "_")
if not dag_id:
dag_id = f"{self._dag_prefix}_{formatted_name}_{flow_panel.uid}"
with DAG(dag_id) as dag:
default_dag_variables: Optional[DAGVariables] = None
if flow_panel.variables:
variables = []
for v in flow_panel.variables:
scope_key = v.scope_key
if v.scope == VARIABLES_SCOPE_FLOW_PRIVATE and not scope_key:
scope_key = dag_id
variables.append(
_DAGVariablesItem(
key=v.key,
name=v.name, # type: ignore
label=v.label,
description=v.description,
value_type=v.value_type,
category=v.category,
scope=v.scope,
scope_key=scope_key,
value=v.value,
user_name=flow_panel.user_name,
sys_code=flow_panel.sys_code,
)
)
default_dag_variables = DAGVariables(items=variables)
with DAG(dag_id, default_dag_variables=default_dag_variables) as dag:
for key, task in key_to_tasks.items():
if not task._node_id:
task.set_node_id(dag._new_node_id())

View File

@@ -3,6 +3,7 @@ from typing import cast
import pytest
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
from dbgpt.core.awel.flow import (
IOField,
@@ -12,7 +13,12 @@ from dbgpt.core.awel.flow import (
ViewMetadata,
ui,
)
from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel
from dbgpt.core.awel.flow.flow_factory import (
FlowData,
FlowFactory,
FlowPanel,
FlowVariables,
)
from ...tests.conftest import variables_provider
@@ -46,6 +52,28 @@ class MyVariablesOperator(MapOperator[str, str]):
key="dbgpt.model.openai.model",
),
),
Parameter.build_from(
"DAG Var 1",
"dag_var1",
type=str,
placeholder="Please select the DAG variable 1",
description="The DAG variable 1.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
),
),
Parameter.build_from(
"DAG Var 2",
"dag_var2",
type=str,
placeholder="Please select the DAG variable 2",
description="The DAG variable 2.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
),
),
],
inputs=[
IOField.build_from(
@@ -65,15 +93,21 @@ class MyVariablesOperator(MapOperator[str, str]):
],
)
def __init__(self, openai_api_key: str, model: str, **kwargs):
def __init__(
self, openai_api_key: str, model: str, dag_var1: str, dag_var2: str, **kwargs
):
super().__init__(**kwargs)
self._openai_api_key = openai_api_key
self._model = model
self._dag_var1 = dag_var1
self._dag_var2 = dag_var2
async def map(self, user_name: str) -> str:
dict_dict = {
"openai_api_key": self._openai_api_key,
"model": self._model,
"dag_var1": self._dag_var1,
"dag_var2": self._dag_var2,
}
json_data = json.dumps(dict_dict, ensure_ascii=False)
return "Your name is %s, and your model info is %s." % (user_name, json_data)
@@ -117,6 +151,10 @@ def json_flow():
"my_test_variables_operator": {
"openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
"model": "${dbgpt.model.openai.model:default_model@global}",
"dag_var1": "${dbgpt.core.flow.params:name1@%s}"
% VARIABLES_SCOPE_FLOW_PRIVATE,
"dag_var2": "${dbgpt.core.flow.params:name2@%s}"
% VARIABLES_SCOPE_FLOW_PRIVATE,
}
}
name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
@@ -208,16 +246,49 @@ def json_flow():
async def test_build_flow(json_flow, variables_provider):
DAGVar.set_variables_provider(variables_provider)
flow_data = FlowData(**json_flow)
variables = [
FlowVariables(
key="dbgpt.core.flow.params",
name="name1",
label="Name 1",
value="value1",
value_type="str",
category="common",
scope=VARIABLES_SCOPE_FLOW_PRIVATE,
# scope_key="my_test_flow",
),
FlowVariables(
key="dbgpt.core.flow.params",
name="name2",
label="Name 2",
value="value2",
value_type="str",
category="common",
scope=VARIABLES_SCOPE_FLOW_PRIVATE,
# scope_key="my_test_flow",
),
]
flow_panel = FlowPanel(
label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed"
label="My Test Flow",
name="my_test_flow",
flow_data=flow_data,
state="deployed",
variables=variables,
)
factory = FlowFactory()
dag = factory.build(flow_panel)
leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
result = await leaf_node.call("Alice")
expected_dict = {
"openai_api_key": "my_openai_api_key",
"model": "GPT-4o",
"dag_var1": "value1",
"dag_var2": "value2",
}
expected_dict_str = json.dumps(expected_dict, ensure_ascii=False)
assert (
result
== "End operator received input: Your name is Alice, and your model info is "
'{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.'
== f"End operator received input: Your name is Alice, and your model info is "
f"{expected_dict_str}."
)

View File

@@ -20,6 +20,7 @@ from typing import (
)
from dbgpt.component import ComponentType, SystemApp
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.util.executor_utils import (
AsyncToSyncIterator,
BlockingFunction,
@@ -28,7 +29,7 @@ from dbgpt.util.executor_utils import (
)
from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar, DAGVariables
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data
if TYPE_CHECKING:
@@ -58,6 +59,7 @@ class WorkflowRunner(ABC, Generic[T]):
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> DAGContext:
"""Execute the workflow starting from a given operator.
@@ -67,6 +69,7 @@ class WorkflowRunner(ABC, Generic[T]):
streaming_call (bool): Whether the call is a streaming call.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables.
Returns:
DAGContext: The context after executing the workflow, containing the final
state and data.
@@ -243,6 +246,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
self,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> OUT:
"""Execute the node and return the output.
@@ -252,6 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
Returns:
OUT: The output of the node after execution.
"""
@@ -259,13 +264,15 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data = {"data": call_data}
with root_tracer.start_span("dbgpt.awel.operator.call"):
out_ctx = await self._runner.execute_workflow(
self, call_data, exist_dag_ctx=dag_ctx
self, call_data, exist_dag_ctx=dag_ctx, dag_variables=dag_variables
)
return out_ctx.current_task_context.task_output.output
def _blocking_call(
self,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT:
"""Execute the node and return the output.
@@ -275,7 +282,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
loop (asyncio.BaseEventLoop): The event loop to run the operator.
Returns:
OUT: The output of the node after execution.
"""
@@ -284,12 +294,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop:
loop = get_or_create_event_loop()
loop = cast(asyncio.BaseEventLoop, loop)
return loop.run_until_complete(self.call(call_data))
return loop.run_until_complete(self.call(call_data, dag_ctx, dag_variables))
async def call_stream(
self,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -299,7 +310,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
@@ -307,7 +318,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data = {"data": call_data}
with root_tracer.start_span("dbgpt.awel.operator.call_stream"):
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
self,
call_data,
streaming_call=True,
exist_dag_ctx=dag_ctx,
dag_variables=dag_variables,
)
task_output = out_ctx.current_task_context.task_output
@@ -328,6 +343,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
def _blocking_call_stream(
self,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -337,7 +354,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
loop (asyncio.BaseEventLoop): The event loop to run the operator.
Returns:
Iterator[OUT]: An iterator over the output stream.
"""
@@ -345,7 +365,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop:
loop = get_or_create_event_loop()
return AsyncToSyncIterator(self.call_stream(call_data), loop)
return AsyncToSyncIterator(
self.call_stream(call_data, dag_ctx, dag_variables), loop
)
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
@@ -373,20 +395,77 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Check if the operator can be skipped in the branch."""
return self._can_skip_in_branch
async def _resolve_variables(self, _: DAGContext):
from ...interface.variables import VariablesPlaceHolder
async def _resolve_variables(self, dag_ctx: DAGContext):
"""Resolve variables in the operator.
Some attributes of the operator may be VariablesPlaceHolder, which need to be
resolved before the operator is executed.
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
"""
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
if not self._variables_provider:
return
if dag_ctx._dag_variables:
# Resolve variables in DAG context
resolve_tasks = []
resolve_items = []
for item in dag_ctx._dag_variables.items:
# TODO: Resolve variables just once?
if isinstance(item.value, VariablesPlaceHolder):
resolve_tasks.append(
self.blocking_func_to_async(
item.value.parse, self._variables_provider
)
)
resolve_items.append(item)
resolved_values = await asyncio.gather(*resolve_tasks)
for item, rv in zip(resolve_items, resolved_values):
item.value = rv
dag_provider: Optional["VariablesProvider"] = None
if dag_ctx._dag_variables:
dag_provider = dag_ctx._dag_variables.to_provider()
# TODO: Resolve variables parallel
for attr, value in self.__dict__.items():
# Handle all attributes that are VariablesPlaceHolder
if isinstance(value, VariablesPlaceHolder):
resolved_value = await self.blocking_func_to_async(
value.parse, self._variables_provider
)
logger.debug(
f"Resolve variable {attr} with value {resolved_value} for {self}"
)
resolved_value: Any = None
default_identifier_map = None
id_key = VariablesIdentifier.from_str_identifier(value.full_key)
if (
id_key.scope == VARIABLES_SCOPE_FLOW_PRIVATE
and id_key.scope_key is None
and self.dag
):
default_identifier_map = {"scope_key": self.dag.dag_id}
if dag_provider:
# First try to resolve the variable with the DAG variables
resolved_value = await self.blocking_func_to_async(
value.parse,
dag_provider,
ignore_not_found_error=True,
default_identifier_map=default_identifier_map,
)
if resolved_value is None:
resolved_value = await self.blocking_func_to_async(
value.parse,
self._variables_provider,
default_identifier_map=default_identifier_map,
)
logger.debug(
f"Resolve variable {attr} with value {resolved_value} for "
f"{self} from system variables"
)
else:
logger.debug(
f"Resolve variable {attr} with value {resolved_value} for "
f"{self} from DAG variables"
)
setattr(self, attr, resolved_value)

View File

@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp
from dbgpt.util.tracer import root_tracer
from ..dag.base import DAGContext, DAGVar
from ..dag.base import DAGContext, DAGVar, DAGVariables
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operators.common_operator import BranchOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
@@ -46,6 +46,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> DAGContext:
"""Execute the workflow.
@@ -57,6 +58,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
Defaults to False.
exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
Defaults to None.
dag_variables (Optional[DAGVariables], optional): The DAG variables.
"""
# Save node output
# dag = node.dag
@@ -71,12 +73,19 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs = exist_dag_ctx._node_to_outputs
share_data = exist_dag_ctx._share_data
event_loop_task_id = exist_dag_ctx._event_loop_task_id
if dag_variables and exist_dag_ctx._dag_variables:
# Merge dag variables, prefer the `dag_variables` in the parameter
dag_variables = dag_variables.merge(exist_dag_ctx._dag_variables)
if node.dag and not dag_variables and node.dag._default_dag_variables:
# Use default dag variables if not set
dag_variables = node.dag._default_dag_variables
dag_ctx = DAGContext(
event_loop_task_id=event_loop_task_id,
node_to_outputs=node_outputs,
share_data=share_data,
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
dag_variables=dag_variables,
)
# if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx

View File

@@ -212,11 +212,18 @@ class VariablesIdentifier(ResourceIdentifier):
}
@classmethod
def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier":
def from_str_identifier(
cls,
str_identifier: str,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> "VariablesIdentifier":
"""Create a VariablesIdentifier from a string identifier.
Args:
str_identifier (str): The string identifier.
default_identifier_map (Optional[Dict[str, str]]): The default identifier
map, which contains the default values for the identifier. Defaults to
None.
Returns:
VariablesIdentifier: The VariablesIdentifier.
@@ -229,13 +236,20 @@ class VariablesIdentifier(ResourceIdentifier):
if not variable_dict.get("name"):
raise ValueError("Invalid string identifier, must have name")
def _get_value(key, default_value: Optional[str] = None) -> Optional[str]:
if variable_dict.get(key) is not None:
return variable_dict.get(key)
if default_identifier_map is not None and default_identifier_map.get(key):
return default_identifier_map.get(key)
return default_value
return cls(
key=variable_dict["key"],
name=variable_dict["name"],
scope=variable_dict.get("scope", "global"),
scope_key=variable_dict.get("scope_key"),
sys_code=variable_dict.get("sys_code"),
user_name=variable_dict.get("user_name"),
scope=variable_dict["scope"],
scope_key=_get_value("scope_key"),
sys_code=_get_value("sys_code"),
user_name=_get_value("user_name"),
)
@@ -256,6 +270,7 @@ class StorageVariables(StorageItem):
encryption_method: Optional[str] = None
salt: Optional[str] = None
enabled: int = 1
description: Optional[str] = None
_identifier: VariablesIdentifier = dataclasses.field(init=False)
@@ -297,6 +312,8 @@ class StorageVariables(StorageItem):
"category": self.category,
"encryption_method": self.encryption_method,
"salt": self.salt,
"enabled": self.enabled,
"description": self.description,
}
def from_object(self, other: "StorageVariables") -> None:
@@ -311,6 +328,8 @@ class StorageVariables(StorageItem):
self.user_name = other.user_name
self.encryption_method = other.encryption_method
self.salt = other.salt
self.enabled = other.enabled
self.description = other.description
@classmethod
def from_identifier(
@@ -347,7 +366,10 @@ class VariablesProvider(BaseComponent, ABC):
@abstractmethod
def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
@@ -416,9 +438,23 @@ class VariablesPlaceHolder:
self.full_key = full_key
self.default_value = default_value
def parse(self, variables_provider: VariablesProvider) -> Any:
def parse(
self,
variables_provider: VariablesProvider,
ignore_not_found_error: bool = False,
default_identifier_map: Optional[Dict[str, str]] = None,
):
"""Parse the variables."""
return variables_provider.get(self.full_key, self.default_value)
try:
return variables_provider.get(
self.full_key,
self.default_value,
default_identifier_map=default_identifier_map,
)
except ValueError as e:
if ignore_not_found_error:
return None
raise e
def __repr__(self):
"""Return the representation of the variables place holder."""
@@ -449,10 +485,13 @@ class StorageVariablesProvider(VariablesProvider):
self.system_app = system_app
def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
key = VariablesIdentifier.from_str_identifier(full_key)
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
if variable is None:
if default_value == _EMPTY_DEFAULT_VALUE:
@@ -641,7 +680,10 @@ class BuiltinVariablesProvider(VariablesProvider, ABC):
self.system_app = system_app
def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
raise NotImplementedError("BuiltinVariablesProvider does not support get.")

View File

@@ -49,6 +49,7 @@ class ServeEntity(Model):
editable = Column(
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
)
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -109,14 +110,14 @@ class VariablesEntity(Model):
String(32),
default="global",
nullable=True,
comment="Variable scope(global,flow,app,agent,datasource,flow:uid,"
"flow:dag_name,agent:agent_name) etc",
comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
"etc)",
)
scope_key = Column(
String(256),
nullable=True,
comment="Variable scope key, default is empty, for scope is 'flow:uid', "
"the scope_key is uid of flow",
comment="Variable scope key, default is empty, for scope is 'flow_priv', "
"the scope_key is dag id of flow",
)
enabled = Column(
Integer,
@@ -124,6 +125,7 @@ class VariablesEntity(Model):
nullable=True,
comment="Variable enabled, 0: disabled, 1: enabled",
)
description = Column(Text, nullable=True, comment="Variable description")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -154,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
error_message = request_dict.get("error_message")
if error_message:
error_message = error_message[:500]
variables_raw = request_dict.get("variables")
variables = (
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
)
new_dict = {
"uid": request_dict.get("uid"),
"dag_id": request_dict.get("dag_id"),
@@ -169,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"define_type": request_dict.get("define_type"),
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
"description": request_dict.get("description"),
"variables": variables,
"user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"),
}
@@ -185,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
REQ: The request
"""
flow_data = json.loads(entity.flow_data)
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServeRequest(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -200,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
define_type=entity.define_type,
editable=ServeEntity.to_bool_editable(entity.editable),
description=entity.description,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -216,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_data = json.loads(entity.flow_data)
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServerResponse(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -231,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
version=entity.version,
editable=ServeEntity.to_bool_editable(entity.editable),
define_type=entity.define_type,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
@@ -271,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
entry.editable = ServeEntity.parse_editable(update_request.editable)
if update_request.define_type:
entry.define_type = update_request.define_type
if update_request.variables:
variables_raw = update_request.get_variables_dict()
entry.variables = (
json.dumps(variables_raw, ensure_ascii=False)
if variables_raw
else None
)
if update_request.user_name:
entry.user_name = update_request.user_name
if update_request.sys_code:
@@ -317,6 +339,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
"enabled": enabled,
"user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"),
"description": request_dict.get("description"),
}
entity = VariablesEntity(**new_dict)
return entity
@@ -348,6 +371,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
enabled=enabled,
user_name=entity.user_name,
sys_code=entity.sys_code,
description=entity.description,
)
def to_response(self, entity: VariablesEntity) -> VariablesResponse:
@@ -382,4 +406,5 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
description=entity.description,
)

View File

@@ -29,6 +29,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
scope_key=item.scope_key,
sys_code=item.sys_code,
user_name=item.user_name,
description=item.description,
)
def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
@@ -46,6 +47,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
scope_key=model.scope_key,
sys_code=model.sys_code,
user_name=model.user_name,
description=model.description,
)
def get_query_for_identifier(

View File

@@ -90,6 +90,8 @@ class VariablesService(
scope_key=request.scope_key,
user_name=request.user_name,
sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
)
self.variables_provider.save(variables)
query = {
@@ -123,6 +125,8 @@ class VariablesService(
scope_key=request.scope_key,
user_name=request.user_name,
sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
)
exist_value = self.variables_provider.get(
variables.identifier.str_identifier, None