feat(core): Support dag scope variables

This commit is contained in:
Fangyin Cheng
2024-08-22 11:20:59 +08:00
parent 1c7a6c9122
commit de702ef8e6
12 changed files with 526 additions and 57 deletions

View File

@@ -347,9 +347,10 @@ CREATE TABLE `dbgpt_serve_variables` (
`category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
`encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
`salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, ""etc)',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
`enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
`description` text DEFAULT NULL COMMENT 'Variable description',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',

View File

@@ -101,9 +101,10 @@ CREATE TABLE `dbgpt_serve_variables` (
`category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
`encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
`salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
`scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, ""etc)',
`scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
`enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
`description` text DEFAULT NULL COMMENT 'Variable description',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',

View File

@@ -20,3 +20,11 @@ del load_dotenv
TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE = "knowledge_factory_domain_type" TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE = "knowledge_factory_domain_type"
TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE = "knowledge_chat_domain_type" TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE = "knowledge_chat_domain_type"
DOMAIN_TYPE_FINANCIAL_REPORT = "FinancialReport" DOMAIN_TYPE_FINANCIAL_REPORT = "FinancialReport"
VARIABLES_SCOPE_GLOBAL = "global"
VARIABLES_SCOPE_APP = "app"
VARIABLES_SCOPE_AGENT = "agent"
VARIABLES_SCOPE_FLOW = "flow"
VARIABLES_SCOPE_DATASOURCE = "datasource"
VARIABLES_SCOPE_FLOW_PRIVATE = "flow_priv"
VARIABLES_SCOPE_AGENT_PRIVATE = "agent_priv"

View File

@@ -5,6 +5,7 @@ DAG is the core component of AWEL, it is used to define the relationship between
import asyncio import asyncio
import contextvars import contextvars
import dataclasses
import logging import logging
import threading import threading
import uuid import uuid
@@ -17,6 +18,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
List, List,
Literal,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@@ -489,6 +491,100 @@ def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}" return f"{task_name}___$$$$$$___{key}"
@dataclasses.dataclass
class _DAGVariablesItem:
"""The DAG variables item.
It is a private class, just used for internal.
"""
key: str
name: str
label: str
value: Any
category: Literal["common", "secret"] = "common"
scope: str = "global"
value_type: Optional[str] = None
scope_key: Optional[str] = None
sys_code: Optional[str] = None
user_name: Optional[str] = None
description: Optional[str] = None
@dataclasses.dataclass
class DAGVariables:
"""The DAG variables."""
items: List[_DAGVariablesItem] = dataclasses.field(default_factory=list)
_cached_provider: Optional["VariablesProvider"] = None
_lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)
def merge(self, dag_variables: "DAGVariables") -> "DAGVariables":
"""Merge the DAG variables.
Args:
dag_variables (DAGVariables): The DAG variables to merge
"""
def _build_key(item: _DAGVariablesItem):
key = "_".join([item.key, item.name, item.scope])
if item.scope_key:
key += f"_{item.scope_key}"
if item.sys_code:
key += f"_{item.sys_code}"
if item.user_name:
key += f"_{item.user_name}"
return key
new_items = []
exist_vars = set()
for item in self.items:
new_items.append(item)
exist_vars.add(_build_key(item))
for item in dag_variables.items:
key = _build_key(item)
if key not in exist_vars:
new_items.append(item)
return DAGVariables(
items=new_items,
_cached_provider=self._cached_provider or dag_variables._cached_provider,
)
def to_provider(self) -> "VariablesProvider":
"""Convert the DAG variables to variables provider.
Returns:
VariablesProvider: The variables provider
"""
if not self._cached_provider:
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
)
with self._lock:
# Create a new provider safely
provider = StorageVariablesProvider()
for item in self.items:
storage_vars = StorageVariables(
key=item.key,
name=item.name,
label=item.label,
value=item.value,
category=item.category,
scope=item.scope,
value_type=item.value_type,
scope_key=item.scope_key,
sys_code=item.sys_code,
user_name=item.user_name,
description=item.description,
)
provider.save(storage_vars)
self._cached_provider = provider
return self._cached_provider
class DAGContext: class DAGContext:
"""The context of current DAG, created when the DAG is running. """The context of current DAG, created when the DAG is running.
@@ -502,6 +598,7 @@ class DAGContext:
event_loop_task_id: int, event_loop_task_id: int,
streaming_call: bool = False, streaming_call: bool = False,
node_name_to_ids: Optional[Dict[str, str]] = None, node_name_to_ids: Optional[Dict[str, str]] = None,
dag_variables: Optional[DAGVariables] = None,
) -> None: ) -> None:
"""Initialize a DAGContext. """Initialize a DAGContext.
@@ -511,6 +608,7 @@ class DAGContext:
streaming_call (bool, optional): Whether the current DAG is streaming call. streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False. Defaults to False.
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
dag_variables (Optional[DAGVariables], optional): The DAG variables.
""" """
if not node_name_to_ids: if not node_name_to_ids:
node_name_to_ids = {} node_name_to_ids = {}
@@ -520,6 +618,7 @@ class DAGContext:
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids self._node_name_to_ids: Dict[str, str] = node_name_to_ids
self._event_loop_task_id = event_loop_task_id self._event_loop_task_id = event_loop_task_id
self._dag_variables = dag_variables
@property @property
def _task_outputs(self) -> Dict[str, TaskContext]: def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -653,6 +752,7 @@ class DAG:
resource_group: Optional[ResourceGroup] = None, resource_group: Optional[ResourceGroup] = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
description: Optional[str] = None, description: Optional[str] = None,
default_dag_variables: Optional[DAGVariables] = None,
) -> None: ) -> None:
"""Initialize a DAG.""" """Initialize a DAG."""
self._dag_id = dag_id self._dag_id = dag_id
@@ -666,6 +766,7 @@ class DAG:
self._resource_group: Optional[ResourceGroup] = resource_group self._resource_group: Optional[ResourceGroup] = resource_group
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {} self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {}
self._default_dag_variables = default_dag_variables
def _append_node(self, node: DAGNode) -> None: def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map: if node.node_id in self.node_map:

View File

@@ -17,6 +17,7 @@ from dbgpt._private.pydantic import (
model_to_dict, model_to_dict,
model_validator, model_validator,
) )
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.core.awel.dag.base import DAG, DAGNode from dbgpt.core.awel.dag.base import DAG, DAGNode
from dbgpt.core.awel.dag.dag_manager import DAGMetadata from dbgpt.core.awel.dag.dag_manager import DAGMetadata
@@ -166,29 +167,23 @@ class FlowData(BaseModel):
viewport: FlowPositionData = Field(..., description="Viewport of the flow") viewport: FlowPositionData = Field(..., description="Viewport of the flow")
class VariablesRequest(BaseModel): class _VariablesRequestBase(BaseModel):
"""Variable request model.
For creating a new variable in the DB-GPT.
"""
key: str = Field( key: str = Field(
..., ...,
description="The key of the variable to create", description="The key of the variable to create",
examples=["dbgpt.model.openai.api_key"], examples=["dbgpt.model.openai.api_key"],
) )
name: str = Field(
...,
description="The name of the variable to create",
examples=["my_first_openai_key"],
)
label: str = Field( label: str = Field(
..., ...,
description="The label of the variable to create", description="The label of the variable to create",
examples=["My First OpenAI Key"], examples=["My First OpenAI Key"],
) )
value: Any = Field(
..., description="The value of the variable to create", examples=["1234567890"] description: Optional[str] = Field(
None,
description="The description of the variable to create",
examples=["Your OpenAI API key"],
) )
value_type: Literal["str", "int", "float", "bool"] = Field( value_type: Literal["str", "int", "float", "bool"] = Field(
"str", "str",
@@ -206,10 +201,26 @@ class VariablesRequest(BaseModel):
examples=["global"], examples=["global"],
) )
scope_key: Optional[str] = Field( scope_key: Optional[str] = Field(
..., None,
description="The scope key of the variable to create", description="The scope key of the variable to create",
examples=["dbgpt"], examples=["dbgpt"],
) )
class VariablesRequest(_VariablesRequestBase):
"""Variable request model.
For creating a new variable in the DB-GPT.
"""
name: str = Field(
...,
description="The name of the variable to create",
examples=["my_first_openai_key"],
)
value: Any = Field(
..., description="The value of the variable to create", examples=["1234567890"]
)
enabled: Optional[bool] = Field( enabled: Optional[bool] = Field(
True, True,
description="Whether the variable is enabled", description="Whether the variable is enabled",
@@ -219,6 +230,80 @@ class VariablesRequest(BaseModel):
sys_code: Optional[str] = Field(None, description="System code") sys_code: Optional[str] = Field(None, description="System code")
class ParsedFlowVariables(BaseModel):
"""Parsed variables for the flow."""
key: str = Field(
...,
description="The key of the variable",
examples=["dbgpt.model.openai.api_key"],
)
name: Optional[str] = Field(
None,
description="The name of the variable",
examples=["my_first_openai_key"],
)
scope: str = Field(
...,
description="The scope of the variable",
examples=["global"],
)
scope_key: Optional[str] = Field(
None,
description="The scope key of the variable",
examples=["dbgpt"],
)
sys_code: Optional[str] = Field(None, description="System code")
user_name: Optional[str] = Field(None, description="User name")
class FlowVariables(_VariablesRequestBase):
"""Variables for the flow."""
name: Optional[str] = Field(
None,
description="The name of the variable",
examples=["my_first_openai_key"],
)
value: Optional[Any] = Field(
None, description="The value of the variable", examples=["1234567890"]
)
parsed_variables: Optional[ParsedFlowVariables] = Field(
None, description="The parsed variables, parsed from the value"
)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if "parsed_variables" not in values:
parsed_variables = cls.parse_value_to_variables(values.get("value"))
if parsed_variables:
values["parsed_variables"] = parsed_variables
return values
@classmethod
def parse_value_to_variables(cls, value: Any) -> Optional[ParsedFlowVariables]:
"""Parse the value to variables.
Args:
value (Any): The value to parse
Returns:
Optional[ParsedFlowVariables]: The parsed variables, None if the value is
invalid
"""
from ...interface.variables import _is_variable_format, parse_variable
if not value or not isinstance(value, str) or not _is_variable_format(value):
return None
variable_dict = parse_variable(value)
return ParsedFlowVariables(**variable_dict)
class State(str, Enum): class State(str, Enum):
"""State of a flow panel.""" """State of a flow panel."""
@@ -409,7 +494,7 @@ class FlowPanel(BaseModel):
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field( metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
default=None, description="The metadata of the flow" default=None, description="The metadata of the flow"
) )
variables: Optional[List[VariablesRequest]] = Field( variables: Optional[List[FlowVariables]] = Field(
default=None, description="The variables of the flow" default=None, description="The variables of the flow"
) )
authors: Optional[List[str]] = Field( authors: Optional[List[str]] = Field(
@@ -437,6 +522,21 @@ class FlowPanel(BaseModel):
"""Convert to dict.""" """Convert to dict."""
return model_to_dict(self, exclude={"flow_dag"}) return model_to_dict(self, exclude={"flow_dag"})
def get_variables_dict(self) -> List[Dict[str, Any]]:
"""Get the variables dict."""
if not self.variables:
return []
return [v.dict() for v in self.variables]
@classmethod
def parse_variables(
cls, variables: Optional[List[Dict[str, Any]]] = None
) -> Optional[List[FlowVariables]]:
"""Parse the variables."""
if not variables:
return None
return [FlowVariables(**v) for v in variables]
class FlowFactory: class FlowFactory:
"""Flow factory.""" """Flow factory."""
@@ -657,10 +757,36 @@ class FlowFactory:
dag_id: Optional[str] = None, dag_id: Optional[str] = None,
) -> DAG: ) -> DAG:
"""Build the DAG.""" """Build the DAG."""
from ..dag.base import DAGVariables, _DAGVariablesItem
formatted_name = flow_panel.name.replace(" ", "_") formatted_name = flow_panel.name.replace(" ", "_")
if not dag_id: if not dag_id:
dag_id = f"{self._dag_prefix}_{formatted_name}_{flow_panel.uid}" dag_id = f"{self._dag_prefix}_{formatted_name}_{flow_panel.uid}"
with DAG(dag_id) as dag:
default_dag_variables: Optional[DAGVariables] = None
if flow_panel.variables:
variables = []
for v in flow_panel.variables:
scope_key = v.scope_key
if v.scope == VARIABLES_SCOPE_FLOW_PRIVATE and not scope_key:
scope_key = dag_id
variables.append(
_DAGVariablesItem(
key=v.key,
name=v.name, # type: ignore
label=v.label,
description=v.description,
value_type=v.value_type,
category=v.category,
scope=v.scope,
scope_key=scope_key,
value=v.value,
user_name=flow_panel.user_name,
sys_code=flow_panel.sys_code,
)
)
default_dag_variables = DAGVariables(items=variables)
with DAG(dag_id, default_dag_variables=default_dag_variables) as dag:
for key, task in key_to_tasks.items(): for key, task in key_to_tasks.items():
if not task._node_id: if not task._node_id:
task.set_node_id(dag._new_node_id()) task.set_node_id(dag._new_node_id())

View File

@@ -3,6 +3,7 @@ from typing import cast
import pytest import pytest
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
from dbgpt.core.awel.flow import ( from dbgpt.core.awel.flow import (
IOField, IOField,
@@ -12,7 +13,12 @@ from dbgpt.core.awel.flow import (
ViewMetadata, ViewMetadata,
ui, ui,
) )
from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel from dbgpt.core.awel.flow.flow_factory import (
FlowData,
FlowFactory,
FlowPanel,
FlowVariables,
)
from ...tests.conftest import variables_provider from ...tests.conftest import variables_provider
@@ -46,6 +52,28 @@ class MyVariablesOperator(MapOperator[str, str]):
key="dbgpt.model.openai.model", key="dbgpt.model.openai.model",
), ),
), ),
Parameter.build_from(
"DAG Var 1",
"dag_var1",
type=str,
placeholder="Please select the DAG variable 1",
description="The DAG variable 1.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
),
),
Parameter.build_from(
"DAG Var 2",
"dag_var2",
type=str,
placeholder="Please select the DAG variable 2",
description="The DAG variable 2.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
),
),
], ],
inputs=[ inputs=[
IOField.build_from( IOField.build_from(
@@ -65,15 +93,21 @@ class MyVariablesOperator(MapOperator[str, str]):
], ],
) )
def __init__(self, openai_api_key: str, model: str, **kwargs): def __init__(
self, openai_api_key: str, model: str, dag_var1: str, dag_var2: str, **kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self._openai_api_key = openai_api_key self._openai_api_key = openai_api_key
self._model = model self._model = model
self._dag_var1 = dag_var1
self._dag_var2 = dag_var2
async def map(self, user_name: str) -> str: async def map(self, user_name: str) -> str:
dict_dict = { dict_dict = {
"openai_api_key": self._openai_api_key, "openai_api_key": self._openai_api_key,
"model": self._model, "model": self._model,
"dag_var1": self._dag_var1,
"dag_var2": self._dag_var2,
} }
json_data = json.dumps(dict_dict, ensure_ascii=False) json_data = json.dumps(dict_dict, ensure_ascii=False)
return "Your name is %s, and your model info is %s." % (user_name, json_data) return "Your name is %s, and your model info is %s." % (user_name, json_data)
@@ -117,6 +151,10 @@ def json_flow():
"my_test_variables_operator": { "my_test_variables_operator": {
"openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}", "openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
"model": "${dbgpt.model.openai.model:default_model@global}", "model": "${dbgpt.model.openai.model:default_model@global}",
"dag_var1": "${dbgpt.core.flow.params:name1@%s}"
% VARIABLES_SCOPE_FLOW_PRIVATE,
"dag_var2": "${dbgpt.core.flow.params:name2@%s}"
% VARIABLES_SCOPE_FLOW_PRIVATE,
} }
} }
name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list} name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
@@ -208,16 +246,49 @@ def json_flow():
async def test_build_flow(json_flow, variables_provider): async def test_build_flow(json_flow, variables_provider):
DAGVar.set_variables_provider(variables_provider) DAGVar.set_variables_provider(variables_provider)
flow_data = FlowData(**json_flow) flow_data = FlowData(**json_flow)
variables = [
FlowVariables(
key="dbgpt.core.flow.params",
name="name1",
label="Name 1",
value="value1",
value_type="str",
category="common",
scope=VARIABLES_SCOPE_FLOW_PRIVATE,
# scope_key="my_test_flow",
),
FlowVariables(
key="dbgpt.core.flow.params",
name="name2",
label="Name 2",
value="value2",
value_type="str",
category="common",
scope=VARIABLES_SCOPE_FLOW_PRIVATE,
# scope_key="my_test_flow",
),
]
flow_panel = FlowPanel( flow_panel = FlowPanel(
label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed" label="My Test Flow",
name="my_test_flow",
flow_data=flow_data,
state="deployed",
variables=variables,
) )
factory = FlowFactory() factory = FlowFactory()
dag = factory.build(flow_panel) dag = factory.build(flow_panel)
leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0]) leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
result = await leaf_node.call("Alice") result = await leaf_node.call("Alice")
expected_dict = {
"openai_api_key": "my_openai_api_key",
"model": "GPT-4o",
"dag_var1": "value1",
"dag_var2": "value2",
}
expected_dict_str = json.dumps(expected_dict, ensure_ascii=False)
assert ( assert (
result result
== "End operator received input: Your name is Alice, and your model info is " == f"End operator received input: Your name is Alice, and your model info is "
'{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.' f"{expected_dict_str}."
) )

View File

@@ -20,6 +20,7 @@ from typing import (
) )
from dbgpt.component import ComponentType, SystemApp from dbgpt.component import ComponentType, SystemApp
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
from dbgpt.util.executor_utils import ( from dbgpt.util.executor_utils import (
AsyncToSyncIterator, AsyncToSyncIterator,
BlockingFunction, BlockingFunction,
@@ -28,7 +29,7 @@ from dbgpt.util.executor_utils import (
) )
from dbgpt.util.tracer import root_tracer from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar from ..dag.base import DAG, DAGContext, DAGNode, DAGVar, DAGVariables
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -58,6 +59,7 @@ class WorkflowRunner(ABC, Generic[T]):
call_data: Optional[CALL_DATA] = None, call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False, streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None, exist_dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> DAGContext: ) -> DAGContext:
"""Execute the workflow starting from a given operator. """Execute the workflow starting from a given operator.
@@ -67,6 +69,7 @@ class WorkflowRunner(ABC, Generic[T]):
streaming_call (bool): Whether the call is a streaming call. streaming_call (bool): Whether the call is a streaming call.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run, exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None. Defaults to None.
dag_variables (DAGVariables): The DAG variables.
Returns: Returns:
DAGContext: The context after executing the workflow, containing the final DAGContext: The context after executing the workflow, containing the final
state and data. state and data.
@@ -243,6 +246,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
self, self,
call_data: Optional[CALL_DATA] = EMPTY_DATA, call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None, dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> OUT: ) -> OUT:
"""Execute the node and return the output. """Execute the node and return the output.
@@ -252,6 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data (CALL_DATA): The data pass to root operator node. call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None. Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
Returns: Returns:
OUT: The output of the node after execution. OUT: The output of the node after execution.
""" """
@@ -259,13 +264,15 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data = {"data": call_data} call_data = {"data": call_data}
with root_tracer.start_span("dbgpt.awel.operator.call"): with root_tracer.start_span("dbgpt.awel.operator.call"):
out_ctx = await self._runner.execute_workflow( out_ctx = await self._runner.execute_workflow(
self, call_data, exist_dag_ctx=dag_ctx self, call_data, exist_dag_ctx=dag_ctx, dag_variables=dag_variables
) )
return out_ctx.current_task_context.task_output.output return out_ctx.current_task_context.task_output.output
def _blocking_call( def _blocking_call(
self, self,
call_data: Optional[CALL_DATA] = EMPTY_DATA, call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
loop: Optional[asyncio.BaseEventLoop] = None, loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT: ) -> OUT:
"""Execute the node and return the output. """Execute the node and return the output.
@@ -275,7 +282,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args: Args:
call_data (CALL_DATA): The data pass to root operator node. call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
loop (asyncio.BaseEventLoop): The event loop to run the operator.
Returns: Returns:
OUT: The output of the node after execution. OUT: The output of the node after execution.
""" """
@@ -284,12 +294,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop: if not loop:
loop = get_or_create_event_loop() loop = get_or_create_event_loop()
loop = cast(asyncio.BaseEventLoop, loop) loop = cast(asyncio.BaseEventLoop, loop)
return loop.run_until_complete(self.call(call_data)) return loop.run_until_complete(self.call(call_data, dag_ctx, dag_variables))
async def call_stream( async def call_stream(
self, self,
call_data: Optional[CALL_DATA] = EMPTY_DATA, call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None, dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> AsyncIterator[OUT]: ) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream. """Execute the node and return the output as a stream.
@@ -299,7 +310,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data (CALL_DATA): The data pass to root operator node. call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None. Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
Returns: Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream. AsyncIterator[OUT]: An asynchronous iterator over the output stream.
""" """
@@ -307,7 +318,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
call_data = {"data": call_data} call_data = {"data": call_data}
with root_tracer.start_span("dbgpt.awel.operator.call_stream"): with root_tracer.start_span("dbgpt.awel.operator.call_stream"):
out_ctx = await self._runner.execute_workflow( out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx self,
call_data,
streaming_call=True,
exist_dag_ctx=dag_ctx,
dag_variables=dag_variables,
) )
task_output = out_ctx.current_task_context.task_output task_output = out_ctx.current_task_context.task_output
@@ -328,6 +343,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
def _blocking_call_stream( def _blocking_call_stream(
self, self,
call_data: Optional[CALL_DATA] = EMPTY_DATA, call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
loop: Optional[asyncio.BaseEventLoop] = None, loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]: ) -> Iterator[OUT]:
"""Execute the node and return the output as a stream. """Execute the node and return the output as a stream.
@@ -337,7 +354,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args: Args:
call_data (CALL_DATA): The data pass to root operator node. call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
dag_variables (DAGVariables): The DAG variables passed to current DAG.
loop (asyncio.BaseEventLoop): The event loop to run the operator.
Returns: Returns:
Iterator[OUT]: An iterator over the output stream. Iterator[OUT]: An iterator over the output stream.
""" """
@@ -345,7 +365,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop: if not loop:
loop = get_or_create_event_loop() loop = get_or_create_event_loop()
return AsyncToSyncIterator(self.call_stream(call_data), loop) return AsyncToSyncIterator(
self.call_stream(call_data, dag_ctx, dag_variables), loop
)
async def blocking_func_to_async( async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs self, func: BlockingFunction, *args, **kwargs
@@ -373,19 +395,76 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Check if the operator can be skipped in the branch.""" """Check if the operator can be skipped in the branch."""
return self._can_skip_in_branch return self._can_skip_in_branch
async def _resolve_variables(self, _: DAGContext): async def _resolve_variables(self, dag_ctx: DAGContext):
from ...interface.variables import VariablesPlaceHolder """Resolve variables in the operator.
Some attributes of the operator may be VariablesPlaceHolder, which need to be
resolved before the operator is executed.
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
"""
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
if not self._variables_provider: if not self._variables_provider:
return return
if dag_ctx._dag_variables:
# Resolve variables in DAG context
resolve_tasks = []
resolve_items = []
for item in dag_ctx._dag_variables.items:
# TODO: Resolve variables just once?
if isinstance(item.value, VariablesPlaceHolder):
resolve_tasks.append(
self.blocking_func_to_async(
item.value.parse, self._variables_provider
)
)
resolve_items.append(item)
resolved_values = await asyncio.gather(*resolve_tasks)
for item, rv in zip(resolve_items, resolved_values):
item.value = rv
dag_provider: Optional["VariablesProvider"] = None
if dag_ctx._dag_variables:
dag_provider = dag_ctx._dag_variables.to_provider()
# TODO: Resolve variables parallel # TODO: Resolve variables parallel
for attr, value in self.__dict__.items(): for attr, value in self.__dict__.items():
# Handle all attributes that are VariablesPlaceHolder
if isinstance(value, VariablesPlaceHolder): if isinstance(value, VariablesPlaceHolder):
resolved_value: Any = None
default_identifier_map = None
id_key = VariablesIdentifier.from_str_identifier(value.full_key)
if (
id_key.scope == VARIABLES_SCOPE_FLOW_PRIVATE
and id_key.scope_key is None
and self.dag
):
default_identifier_map = {"scope_key": self.dag.dag_id}
if dag_provider:
# First try to resolve the variable with the DAG variables
resolved_value = await self.blocking_func_to_async( resolved_value = await self.blocking_func_to_async(
value.parse, self._variables_provider value.parse,
dag_provider,
ignore_not_found_error=True,
default_identifier_map=default_identifier_map,
)
if resolved_value is None:
resolved_value = await self.blocking_func_to_async(
value.parse,
self._variables_provider,
default_identifier_map=default_identifier_map,
) )
logger.debug( logger.debug(
f"Resolve variable {attr} with value {resolved_value} for {self}" f"Resolve variable {attr} with value {resolved_value} for "
f"{self} from system variables"
)
else:
logger.debug(
f"Resolve variable {attr} with value {resolved_value} for "
f"{self} from DAG variables"
) )
setattr(self, attr, resolved_value) setattr(self, attr, resolved_value)

View File

@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
from dbgpt.util.tracer import root_tracer from dbgpt.util.tracer import root_tracer
from ..dag.base import DAGContext, DAGVar from ..dag.base import DAGContext, DAGVar, DAGVariables
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operators.common_operator import BranchOperator from ..operators.common_operator import BranchOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState from ..task.base import SKIP_DATA, TaskContext, TaskState
@@ -46,6 +46,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
call_data: Optional[CALL_DATA] = None, call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False, streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None, exist_dag_ctx: Optional[DAGContext] = None,
dag_variables: Optional[DAGVariables] = None,
) -> DAGContext: ) -> DAGContext:
"""Execute the workflow. """Execute the workflow.
@@ -57,6 +58,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
Defaults to False. Defaults to False.
exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context. exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
Defaults to None. Defaults to None.
dag_variables (Optional[DAGVariables], optional): The DAG variables.
""" """
# Save node output # Save node output
# dag = node.dag # dag = node.dag
@@ -71,12 +73,19 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs = exist_dag_ctx._node_to_outputs node_outputs = exist_dag_ctx._node_to_outputs
share_data = exist_dag_ctx._share_data share_data = exist_dag_ctx._share_data
event_loop_task_id = exist_dag_ctx._event_loop_task_id event_loop_task_id = exist_dag_ctx._event_loop_task_id
if dag_variables and exist_dag_ctx._dag_variables:
# Merge dag variables, prefer the `dag_variables` in the parameter
dag_variables = dag_variables.merge(exist_dag_ctx._dag_variables)
if node.dag and not dag_variables and node.dag._default_dag_variables:
# Use default dag variables if not set
dag_variables = node.dag._default_dag_variables
dag_ctx = DAGContext( dag_ctx = DAGContext(
event_loop_task_id=event_loop_task_id, event_loop_task_id=event_loop_task_id,
node_to_outputs=node_outputs, node_to_outputs=node_outputs,
share_data=share_data, share_data=share_data,
streaming_call=streaming_call, streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids, node_name_to_ids=job_manager._node_name_to_ids,
dag_variables=dag_variables,
) )
# if node.dag: # if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx # self._running_dag_ctx[node.dag.dag_id] = dag_ctx

View File

@@ -212,11 +212,18 @@ class VariablesIdentifier(ResourceIdentifier):
} }
@classmethod @classmethod
def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier": def from_str_identifier(
cls,
str_identifier: str,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> "VariablesIdentifier":
"""Create a VariablesIdentifier from a string identifier. """Create a VariablesIdentifier from a string identifier.
Args: Args:
str_identifier (str): The string identifier. str_identifier (str): The string identifier.
default_identifier_map (Optional[Dict[str, str]]): The default identifier
map, which contains the default values for the identifier. Defaults to
None.
Returns: Returns:
VariablesIdentifier: The VariablesIdentifier. VariablesIdentifier: The VariablesIdentifier.
@@ -229,13 +236,20 @@ class VariablesIdentifier(ResourceIdentifier):
if not variable_dict.get("name"): if not variable_dict.get("name"):
raise ValueError("Invalid string identifier, must have name") raise ValueError("Invalid string identifier, must have name")
def _get_value(key, default_value: Optional[str] = None) -> Optional[str]:
if variable_dict.get(key) is not None:
return variable_dict.get(key)
if default_identifier_map is not None and default_identifier_map.get(key):
return default_identifier_map.get(key)
return default_value
return cls( return cls(
key=variable_dict["key"], key=variable_dict["key"],
name=variable_dict["name"], name=variable_dict["name"],
scope=variable_dict.get("scope", "global"), scope=variable_dict["scope"],
scope_key=variable_dict.get("scope_key"), scope_key=_get_value("scope_key"),
sys_code=variable_dict.get("sys_code"), sys_code=_get_value("sys_code"),
user_name=variable_dict.get("user_name"), user_name=_get_value("user_name"),
) )
@@ -256,6 +270,7 @@ class StorageVariables(StorageItem):
encryption_method: Optional[str] = None encryption_method: Optional[str] = None
salt: Optional[str] = None salt: Optional[str] = None
enabled: int = 1 enabled: int = 1
description: Optional[str] = None
_identifier: VariablesIdentifier = dataclasses.field(init=False) _identifier: VariablesIdentifier = dataclasses.field(init=False)
@@ -297,6 +312,8 @@ class StorageVariables(StorageItem):
"category": self.category, "category": self.category,
"encryption_method": self.encryption_method, "encryption_method": self.encryption_method,
"salt": self.salt, "salt": self.salt,
"enabled": self.enabled,
"description": self.description,
} }
def from_object(self, other: "StorageVariables") -> None: def from_object(self, other: "StorageVariables") -> None:
@@ -311,6 +328,8 @@ class StorageVariables(StorageItem):
self.user_name = other.user_name self.user_name = other.user_name
self.encryption_method = other.encryption_method self.encryption_method = other.encryption_method
self.salt = other.salt self.salt = other.salt
self.enabled = other.enabled
self.description = other.description
@classmethod @classmethod
def from_identifier( def from_identifier(
@@ -347,7 +366,10 @@ class VariablesProvider(BaseComponent, ABC):
@abstractmethod @abstractmethod
def get( def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any: ) -> Any:
"""Query variables from storage.""" """Query variables from storage."""
@@ -416,9 +438,23 @@ class VariablesPlaceHolder:
self.full_key = full_key self.full_key = full_key
self.default_value = default_value self.default_value = default_value
def parse(self, variables_provider: VariablesProvider) -> Any: def parse(
self,
variables_provider: VariablesProvider,
ignore_not_found_error: bool = False,
default_identifier_map: Optional[Dict[str, str]] = None,
):
"""Parse the variables.""" """Parse the variables."""
return variables_provider.get(self.full_key, self.default_value) try:
return variables_provider.get(
self.full_key,
self.default_value,
default_identifier_map=default_identifier_map,
)
except ValueError as e:
if ignore_not_found_error:
return None
raise e
def __repr__(self): def __repr__(self):
"""Return the representation of the variables place holder.""" """Return the representation of the variables place holder."""
@@ -449,10 +485,13 @@ class StorageVariablesProvider(VariablesProvider):
self.system_app = system_app self.system_app = system_app
def get( def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any: ) -> Any:
"""Query variables from storage.""" """Query variables from storage."""
key = VariablesIdentifier.from_str_identifier(full_key) key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables) variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
if variable is None: if variable is None:
if default_value == _EMPTY_DEFAULT_VALUE: if default_value == _EMPTY_DEFAULT_VALUE:
@@ -641,7 +680,10 @@ class BuiltinVariablesProvider(VariablesProvider, ABC):
self.system_app = system_app self.system_app = system_app
def get( def get(
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any: ) -> Any:
"""Query variables from storage.""" """Query variables from storage."""
raise NotImplementedError("BuiltinVariablesProvider does not support get.") raise NotImplementedError("BuiltinVariablesProvider does not support get.")

View File

@@ -49,6 +49,7 @@ class ServeEntity(Model):
editable = Column( editable = Column(
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable" Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
) )
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
user_name = Column(String(128), index=True, nullable=True, comment="User name") user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code") sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -109,14 +110,14 @@ class VariablesEntity(Model):
String(32), String(32),
default="global", default="global",
nullable=True, nullable=True,
comment="Variable scope(global,flow,app,agent,datasource,flow:uid," comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
"flow:dag_name,agent:agent_name) etc", "etc)",
) )
scope_key = Column( scope_key = Column(
String(256), String(256),
nullable=True, nullable=True,
comment="Variable scope key, default is empty, for scope is 'flow:uid', " comment="Variable scope key, default is empty, for scope is 'flow_priv', "
"the scope_key is uid of flow", "the scope_key is dag id of flow",
) )
enabled = Column( enabled = Column(
Integer, Integer,
@@ -124,6 +125,7 @@ class VariablesEntity(Model):
nullable=True, nullable=True,
comment="Variable enabled, 0: disabled, 1: enabled", comment="Variable enabled, 0: disabled, 1: enabled",
) )
description = Column(Text, nullable=True, comment="Variable description")
user_name = Column(String(128), index=True, nullable=True, comment="User name") user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code") sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -154,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
error_message = request_dict.get("error_message") error_message = request_dict.get("error_message")
if error_message: if error_message:
error_message = error_message[:500] error_message = error_message[:500]
variables_raw = request_dict.get("variables")
variables = (
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
)
new_dict = { new_dict = {
"uid": request_dict.get("uid"), "uid": request_dict.get("uid"),
"dag_id": request_dict.get("dag_id"), "dag_id": request_dict.get("dag_id"),
@@ -169,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"define_type": request_dict.get("define_type"), "define_type": request_dict.get("define_type"),
"editable": ServeEntity.parse_editable(request_dict.get("editable")), "editable": ServeEntity.parse_editable(request_dict.get("editable")),
"description": request_dict.get("description"), "description": request_dict.get("description"),
"variables": variables,
"user_name": request_dict.get("user_name"), "user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"), "sys_code": request_dict.get("sys_code"),
} }
@@ -185,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
REQ: The request REQ: The request
""" """
flow_data = json.loads(entity.flow_data) flow_data = json.loads(entity.flow_data)
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServeRequest( return ServeRequest(
uid=entity.uid, uid=entity.uid,
dag_id=entity.dag_id, dag_id=entity.dag_id,
@@ -200,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
define_type=entity.define_type, define_type=entity.define_type,
editable=ServeEntity.to_bool_editable(entity.editable), editable=ServeEntity.to_bool_editable(entity.editable),
description=entity.description, description=entity.description,
variables=variables,
user_name=entity.user_name, user_name=entity.user_name,
sys_code=entity.sys_code, sys_code=entity.sys_code,
) )
@@ -216,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_data = json.loads(entity.flow_data) flow_data = json.loads(entity.flow_data)
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServerResponse( return ServerResponse(
uid=entity.uid, uid=entity.uid,
dag_id=entity.dag_id, dag_id=entity.dag_id,
@@ -231,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
version=entity.version, version=entity.version,
editable=ServeEntity.to_bool_editable(entity.editable), editable=ServeEntity.to_bool_editable(entity.editable),
define_type=entity.define_type, define_type=entity.define_type,
variables=variables,
user_name=entity.user_name, user_name=entity.user_name,
sys_code=entity.sys_code, sys_code=entity.sys_code,
gmt_created=gmt_created_str, gmt_created=gmt_created_str,
@@ -271,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
entry.editable = ServeEntity.parse_editable(update_request.editable) entry.editable = ServeEntity.parse_editable(update_request.editable)
if update_request.define_type: if update_request.define_type:
entry.define_type = update_request.define_type entry.define_type = update_request.define_type
if update_request.variables:
variables_raw = update_request.get_variables_dict()
entry.variables = (
json.dumps(variables_raw, ensure_ascii=False)
if variables_raw
else None
)
if update_request.user_name: if update_request.user_name:
entry.user_name = update_request.user_name entry.user_name = update_request.user_name
if update_request.sys_code: if update_request.sys_code:
@@ -317,6 +339,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
"enabled": enabled, "enabled": enabled,
"user_name": request_dict.get("user_name"), "user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"), "sys_code": request_dict.get("sys_code"),
"description": request_dict.get("description"),
} }
entity = VariablesEntity(**new_dict) entity = VariablesEntity(**new_dict)
return entity return entity
@@ -348,6 +371,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
enabled=enabled, enabled=enabled,
user_name=entity.user_name, user_name=entity.user_name,
sys_code=entity.sys_code, sys_code=entity.sys_code,
description=entity.description,
) )
def to_response(self, entity: VariablesEntity) -> VariablesResponse: def to_response(self, entity: VariablesEntity) -> VariablesResponse:
@@ -382,4 +406,5 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
sys_code=entity.sys_code, sys_code=entity.sys_code,
gmt_created=gmt_created_str, gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str, gmt_modified=gmt_modified_str,
description=entity.description,
) )

View File

@@ -29,6 +29,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
scope_key=item.scope_key, scope_key=item.scope_key,
sys_code=item.sys_code, sys_code=item.sys_code,
user_name=item.user_name, user_name=item.user_name,
description=item.description,
) )
def from_storage_format(self, model: VariablesEntity) -> StorageVariables: def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
@@ -46,6 +47,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
scope_key=model.scope_key, scope_key=model.scope_key,
sys_code=model.sys_code, sys_code=model.sys_code,
user_name=model.user_name, user_name=model.user_name,
description=model.description,
) )
def get_query_for_identifier( def get_query_for_identifier(

View File

@@ -90,6 +90,8 @@ class VariablesService(
scope_key=request.scope_key, scope_key=request.scope_key,
user_name=request.user_name, user_name=request.user_name,
sys_code=request.sys_code, sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
) )
self.variables_provider.save(variables) self.variables_provider.save(variables)
query = { query = {
@@ -123,6 +125,8 @@ class VariablesService(
scope_key=request.scope_key, scope_key=request.scope_key,
user_name=request.user_name, user_name=request.user_name,
sys_code=request.sys_code, sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
) )
exist_value = self.variables_provider.get( exist_value = self.variables_provider.get(
variables.identifier.str_identifier, None variables.identifier.str_identifier, None