mirror of https://github.com/csunny/DB-GPT.git

feat(core): Support dag scope variables
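In brief: a flow can now carry its own variables. The change adds `VARIABLES_SCOPE_FLOW_PRIVATE` ("flow_priv") and related scope constants to `dbgpt.configs`, a `DAGVariables` container that the flow factory attaches to each built DAG, `dag_variables` parameters threaded through `call`/`call_stream` and the workflow runner, and operator-side resolution that consults the DAG-scoped provider before the system-wide one. A minimal sketch of the placeholder syntax involved (key and name taken from the test added in this commit):

```python
# ${key:name@scope} — with scope "flow_priv" the reference is private to one
# flow; when no scope_key is given, it defaults to the flow's dag_id.
openai_key = "${dbgpt.model.openai.api_key:my_key@global}"  # system-wide
dag_var1 = "${dbgpt.core.flow.params:name1@flow_priv}"      # DAG-scoped
```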
@@ -347,9 +347,10 @@ CREATE TABLE `dbgpt_serve_variables` (
   `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
   `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
   `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
-  `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc',
-  `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow',
+  `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, etc)',
+  `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
   `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
+  `description` text DEFAULT NULL COMMENT 'Variable description',
   `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
   `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
@@ -101,9 +101,10 @@ CREATE TABLE `dbgpt_serve_variables` (
   `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)',
   `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)',
   `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt',
-  `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc',
-  `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow',
+  `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, etc)',
+  `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow',
   `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled',
+  `description` text DEFAULT NULL COMMENT 'Variable description',
   `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
   `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
@@ -20,3 +20,11 @@ del load_dotenv
 TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE = "knowledge_factory_domain_type"
 TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE = "knowledge_chat_domain_type"
 DOMAIN_TYPE_FINANCIAL_REPORT = "FinancialReport"
+
+VARIABLES_SCOPE_GLOBAL = "global"
+VARIABLES_SCOPE_APP = "app"
+VARIABLES_SCOPE_AGENT = "agent"
+VARIABLES_SCOPE_FLOW = "flow"
+VARIABLES_SCOPE_DATASOURCE = "datasource"
+VARIABLES_SCOPE_FLOW_PRIVATE = "flow_priv"
+VARIABLES_SCOPE_AGENT_PRIVATE = "agent_priv"
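These constants replace the earlier free-form `flow:uid` / `agent:agent_name` scope spellings (see the SQL comment changes above). The test added later in this commit composes placeholders from them instead of hard-coding the strings:

```python
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE

dag_var1 = "${dbgpt.core.flow.params:name1@%s}" % VARIABLES_SCOPE_FLOW_PRIVATE
assert dag_var1 == "${dbgpt.core.flow.params:name1@flow_priv}"
```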
@@ -5,6 +5,7 @@ DAG is the core component of AWEL, it is used to define the relationship between
 import asyncio
 import contextvars
+import dataclasses
 import logging
 import threading
 import uuid
@@ -17,6 +18,7 @@ from typing import (
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
@@ -489,6 +491,100 @@ def _build_task_key(task_name: str, key: str) -> str:
     return f"{task_name}___$$$$$$___{key}"


+@dataclasses.dataclass
+class _DAGVariablesItem:
+    """The DAG variables item.
+
+    It is a private class, just used for internal.
+    """
+
+    key: str
+    name: str
+    label: str
+    value: Any
+    category: Literal["common", "secret"] = "common"
+    scope: str = "global"
+    value_type: Optional[str] = None
+    scope_key: Optional[str] = None
+    sys_code: Optional[str] = None
+    user_name: Optional[str] = None
+    description: Optional[str] = None
+
+
+@dataclasses.dataclass
+class DAGVariables:
+    """The DAG variables."""
+
+    items: List[_DAGVariablesItem] = dataclasses.field(default_factory=list)
+    _cached_provider: Optional["VariablesProvider"] = None
+    _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)
+
+    def merge(self, dag_variables: "DAGVariables") -> "DAGVariables":
+        """Merge the DAG variables.
+
+        Args:
+            dag_variables (DAGVariables): The DAG variables to merge
+        """
+
+        def _build_key(item: _DAGVariablesItem):
+            key = "_".join([item.key, item.name, item.scope])
+            if item.scope_key:
+                key += f"_{item.scope_key}"
+            if item.sys_code:
+                key += f"_{item.sys_code}"
+            if item.user_name:
+                key += f"_{item.user_name}"
+            return key
+
+        new_items = []
+        exist_vars = set()
+        for item in self.items:
+            new_items.append(item)
+            exist_vars.add(_build_key(item))
+        for item in dag_variables.items:
+            key = _build_key(item)
+            if key not in exist_vars:
+                new_items.append(item)
+        return DAGVariables(
+            items=new_items,
+            _cached_provider=self._cached_provider or dag_variables._cached_provider,
+        )
+
+    def to_provider(self) -> "VariablesProvider":
+        """Convert the DAG variables to variables provider.
+
+        Returns:
+            VariablesProvider: The variables provider
+        """
+        if not self._cached_provider:
+            from ...interface.variables import (
+                StorageVariables,
+                StorageVariablesProvider,
+            )
+
+            with self._lock:
+                # Create a new provider safely
+                provider = StorageVariablesProvider()
+                for item in self.items:
+                    storage_vars = StorageVariables(
+                        key=item.key,
+                        name=item.name,
+                        label=item.label,
+                        value=item.value,
+                        category=item.category,
+                        scope=item.scope,
+                        value_type=item.value_type,
+                        scope_key=item.scope_key,
+                        sys_code=item.sys_code,
+                        user_name=item.user_name,
+                        description=item.description,
+                    )
+                    provider.save(storage_vars)
+                self._cached_provider = provider
+
+        return self._cached_provider
+
+
 class DAGContext:
     """The context of current DAG, created when the DAG is running.
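A usage sketch (with hypothetical values) of the merge semantics defined above: the receiver's items win, and an incoming item is dropped when its composite key — built from key, name, scope, and the optional scope_key/sys_code/user_name — collides with an existing one:

```python
base = DAGVariables(items=[_DAGVariablesItem(key="k", name="n", label="A", value=1)])
extra = DAGVariables(
    items=[
        _DAGVariablesItem(key="k", name="n", label="B", value=2),  # duplicate key, dropped
        _DAGVariablesItem(key="k", name="m", label="C", value=3),  # new, kept
    ]
)
merged = base.merge(extra)
assert [item.value for item in merged.items] == [1, 3]
```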
@@ -502,6 +598,7 @@ class DAGContext:
         event_loop_task_id: int,
         streaming_call: bool = False,
         node_name_to_ids: Optional[Dict[str, str]] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> None:
         """Initialize a DAGContext.

@@ -511,6 +608,7 @@ class DAGContext:
             streaming_call (bool, optional): Whether the current DAG is streaming call.
                 Defaults to False.
             node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
+            dag_variables (Optional[DAGVariables], optional): The DAG variables.
         """
         if not node_name_to_ids:
             node_name_to_ids = {}
@@ -520,6 +618,7 @@ class DAGContext:
         self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids
         self._event_loop_task_id = event_loop_task_id
+        self._dag_variables = dag_variables

     @property
     def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -653,6 +752,7 @@ class DAG:
         resource_group: Optional[ResourceGroup] = None,
         tags: Optional[Dict[str, str]] = None,
         description: Optional[str] = None,
+        default_dag_variables: Optional[DAGVariables] = None,
     ) -> None:
         """Initialize a DAG."""
         self._dag_id = dag_id
@@ -666,6 +766,7 @@ class DAG:
         self._resource_group: Optional[ResourceGroup] = resource_group
         self._lock = asyncio.Lock()
         self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {}
+        self._default_dag_variables = default_dag_variables

     def _append_node(self, node: DAGNode) -> None:
         if node.node_id in self.node_map:
@@ -17,6 +17,7 @@ from dbgpt._private.pydantic import (
     model_to_dict,
     model_validator,
 )
+from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
 from dbgpt.core.awel.dag.base import DAG, DAGNode
 from dbgpt.core.awel.dag.dag_manager import DAGMetadata

@@ -166,29 +167,23 @@ class FlowData(BaseModel):
     viewport: FlowPositionData = Field(..., description="Viewport of the flow")


-class VariablesRequest(BaseModel):
-    """Variable request model.
-
-    For creating a new variable in the DB-GPT.
-    """
-
+class _VariablesRequestBase(BaseModel):
     key: str = Field(
         ...,
         description="The key of the variable to create",
         examples=["dbgpt.model.openai.api_key"],
     )
-    name: str = Field(
-        ...,
-        description="The name of the variable to create",
-        examples=["my_first_openai_key"],
-    )

     label: str = Field(
         ...,
         description="The label of the variable to create",
         examples=["My First OpenAI Key"],
     )
-    value: Any = Field(
-        ..., description="The value of the variable to create", examples=["1234567890"]
-    )
     description: Optional[str] = Field(
         None,
         description="The description of the variable to create",
         examples=["Your OpenAI API key"],
     )
     value_type: Literal["str", "int", "float", "bool"] = Field(
         "str",
@@ -206,10 +201,26 @@ class VariablesRequest(BaseModel):
         examples=["global"],
     )
     scope_key: Optional[str] = Field(
-        ...,
+        None,
         description="The scope key of the variable to create",
         examples=["dbgpt"],
     )
+
+
+class VariablesRequest(_VariablesRequestBase):
+    """Variable request model.
+
+    For creating a new variable in the DB-GPT.
+    """
+
+    name: str = Field(
+        ...,
+        description="The name of the variable to create",
+        examples=["my_first_openai_key"],
+    )
+    value: Any = Field(
+        ..., description="The value of the variable to create", examples=["1234567890"]
+    )
     enabled: Optional[bool] = Field(
         True,
         description="Whether the variable is enabled",
@@ -219,6 +230,80 @@ class VariablesRequest(BaseModel):
     sys_code: Optional[str] = Field(None, description="System code")


+class ParsedFlowVariables(BaseModel):
+    """Parsed variables for the flow."""
+
+    key: str = Field(
+        ...,
+        description="The key of the variable",
+        examples=["dbgpt.model.openai.api_key"],
+    )
+    name: Optional[str] = Field(
+        None,
+        description="The name of the variable",
+        examples=["my_first_openai_key"],
+    )
+    scope: str = Field(
+        ...,
+        description="The scope of the variable",
+        examples=["global"],
+    )
+    scope_key: Optional[str] = Field(
+        None,
+        description="The scope key of the variable",
+        examples=["dbgpt"],
+    )
+    sys_code: Optional[str] = Field(None, description="System code")
+    user_name: Optional[str] = Field(None, description="User name")
+
+
+class FlowVariables(_VariablesRequestBase):
+    """Variables for the flow."""
+
+    name: Optional[str] = Field(
+        None,
+        description="The name of the variable",
+        examples=["my_first_openai_key"],
+    )
+    value: Optional[Any] = Field(
+        None, description="The value of the variable", examples=["1234567890"]
+    )
+    parsed_variables: Optional[ParsedFlowVariables] = Field(
+        None, description="The parsed variables, parsed from the value"
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
+        if "parsed_variables" not in values:
+            parsed_variables = cls.parse_value_to_variables(values.get("value"))
+            if parsed_variables:
+                values["parsed_variables"] = parsed_variables
+        return values
+
+    @classmethod
+    def parse_value_to_variables(cls, value: Any) -> Optional[ParsedFlowVariables]:
+        """Parse the value to variables.
+
+        Args:
+            value (Any): The value to parse
+
+        Returns:
+            Optional[ParsedFlowVariables]: The parsed variables, None if the value is
+                invalid
+        """
+        from ...interface.variables import _is_variable_format, parse_variable
+
+        if not value or not isinstance(value, str) or not _is_variable_format(value):
+            return None
+
+        variable_dict = parse_variable(value)
+        return ParsedFlowVariables(**variable_dict)
+
+
 class State(str, Enum):
     """State of a flow panel."""
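A hedged example of what the `pre_fill` validator above produces for a placeholder-style value; it assumes `parse_variable` returns exactly the fields mirrored by `ParsedFlowVariables`:

```python
v = FlowVariables(
    key="dbgpt.core.flow.params",
    label="Name 1",
    value="${dbgpt.core.flow.params:name1@flow_priv}",
    value_type="str",
    category="common",
    scope="flow_priv",
)
# pre_fill ran parse_value_to_variables on the value, so the reference is now
# available in structured form:
assert v.parsed_variables is not None
assert v.parsed_variables.name == "name1"
assert v.parsed_variables.scope == "flow_priv"
```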
@@ -409,7 +494,7 @@ class FlowPanel(BaseModel):
     metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
         default=None, description="The metadata of the flow"
     )
-    variables: Optional[List[VariablesRequest]] = Field(
+    variables: Optional[List[FlowVariables]] = Field(
         default=None, description="The variables of the flow"
     )
     authors: Optional[List[str]] = Field(
@@ -437,6 +522,21 @@ class FlowPanel(BaseModel):
         """Convert to dict."""
         return model_to_dict(self, exclude={"flow_dag"})

+    def get_variables_dict(self) -> List[Dict[str, Any]]:
+        """Get the variables dict."""
+        if not self.variables:
+            return []
+        return [v.dict() for v in self.variables]
+
+    @classmethod
+    def parse_variables(
+        cls, variables: Optional[List[Dict[str, Any]]] = None
+    ) -> Optional[List[FlowVariables]]:
+        """Parse the variables."""
+        if not variables:
+            return None
+        return [FlowVariables(**v) for v in variables]
+

 class FlowFactory:
     """Flow factory."""
@@ -657,10 +757,36 @@ class FlowFactory:
         dag_id: Optional[str] = None,
     ) -> DAG:
         """Build the DAG."""
+        from ..dag.base import DAGVariables, _DAGVariablesItem
+
         formatted_name = flow_panel.name.replace(" ", "_")
         if not dag_id:
             dag_id = f"{self._dag_prefix}_{formatted_name}_{flow_panel.uid}"
-        with DAG(dag_id) as dag:
+
+        default_dag_variables: Optional[DAGVariables] = None
+        if flow_panel.variables:
+            variables = []
+            for v in flow_panel.variables:
+                scope_key = v.scope_key
+                if v.scope == VARIABLES_SCOPE_FLOW_PRIVATE and not scope_key:
+                    scope_key = dag_id
+                variables.append(
+                    _DAGVariablesItem(
+                        key=v.key,
+                        name=v.name,  # type: ignore
+                        label=v.label,
+                        description=v.description,
+                        value_type=v.value_type,
+                        category=v.category,
+                        scope=v.scope,
+                        scope_key=scope_key,
+                        value=v.value,
+                        user_name=flow_panel.user_name,
+                        sys_code=flow_panel.sys_code,
+                    )
+                )
+            default_dag_variables = DAGVariables(items=variables)
+        with DAG(dag_id, default_dag_variables=default_dag_variables) as dag:
             for key, task in key_to_tasks.items():
                 if not task._node_id:
                     task.set_node_id(dag._new_node_id())
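Worth noting in `build_dag` above: a `flow_priv` variable declared without an explicit `scope_key` is pinned to the generated `dag_id`, so each built flow gets its own private namespace. A small restatement of that rule (the helper name and dag_id are ours, for illustration only):

```python
from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE

def _effective_scope_key(scope, scope_key, dag_id):
    # Mirrors the defaulting rule in build_dag above.
    if scope == VARIABLES_SCOPE_FLOW_PRIVATE and not scope_key:
        return dag_id
    return scope_key

assert _effective_scope_key("flow_priv", None, "flow_dag_x_uid1") == "flow_dag_x_uid1"
assert _effective_scope_key("global", None, "flow_dag_x_uid1") is None
```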
@@ -3,6 +3,7 @@ from typing import cast

 import pytest

+from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
 from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
 from dbgpt.core.awel.flow import (
     IOField,
@@ -12,7 +13,12 @@ from dbgpt.core.awel.flow import (
     ViewMetadata,
     ui,
 )
-from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel
+from dbgpt.core.awel.flow.flow_factory import (
+    FlowData,
+    FlowFactory,
+    FlowPanel,
+    FlowVariables,
+)

 from ...tests.conftest import variables_provider

@@ -46,6 +52,28 @@ class MyVariablesOperator(MapOperator[str, str]):
                     key="dbgpt.model.openai.model",
                 ),
             ),
+            Parameter.build_from(
+                "DAG Var 1",
+                "dag_var1",
+                type=str,
+                placeholder="Please select the DAG variable 1",
+                description="The DAG variable 1.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
+                ),
+            ),
+            Parameter.build_from(
+                "DAG Var 2",
+                "dag_var2",
+                type=str,
+                placeholder="Please select the DAG variable 2",
+                description="The DAG variable 2.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE
+                ),
+            ),
         ],
         inputs=[
             IOField.build_from(
@@ -65,15 +93,21 @@ class MyVariablesOperator(MapOperator[str, str]):
         ],
     )

-    def __init__(self, openai_api_key: str, model: str, **kwargs):
+    def __init__(
+        self, openai_api_key: str, model: str, dag_var1: str, dag_var2: str, **kwargs
+    ):
         super().__init__(**kwargs)
         self._openai_api_key = openai_api_key
         self._model = model
+        self._dag_var1 = dag_var1
+        self._dag_var2 = dag_var2

     async def map(self, user_name: str) -> str:
         dict_dict = {
             "openai_api_key": self._openai_api_key,
             "model": self._model,
+            "dag_var1": self._dag_var1,
+            "dag_var2": self._dag_var2,
         }
         json_data = json.dumps(dict_dict, ensure_ascii=False)
         return "Your name is %s, and your model info is %s." % (user_name, json_data)
@@ -117,6 +151,10 @@ def json_flow():
         "my_test_variables_operator": {
             "openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
             "model": "${dbgpt.model.openai.model:default_model@global}",
+            "dag_var1": "${dbgpt.core.flow.params:name1@%s}"
+            % VARIABLES_SCOPE_FLOW_PRIVATE,
+            "dag_var2": "${dbgpt.core.flow.params:name2@%s}"
+            % VARIABLES_SCOPE_FLOW_PRIVATE,
         }
     }
     name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
@@ -208,16 +246,49 @@ def json_flow():
 async def test_build_flow(json_flow, variables_provider):
     DAGVar.set_variables_provider(variables_provider)
     flow_data = FlowData(**json_flow)
+    variables = [
+        FlowVariables(
+            key="dbgpt.core.flow.params",
+            name="name1",
+            label="Name 1",
+            value="value1",
+            value_type="str",
+            category="common",
+            scope=VARIABLES_SCOPE_FLOW_PRIVATE,
+            # scope_key="my_test_flow",
+        ),
+        FlowVariables(
+            key="dbgpt.core.flow.params",
+            name="name2",
+            label="Name 2",
+            value="value2",
+            value_type="str",
+            category="common",
+            scope=VARIABLES_SCOPE_FLOW_PRIVATE,
+            # scope_key="my_test_flow",
+        ),
+    ]
     flow_panel = FlowPanel(
-        label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed"
+        label="My Test Flow",
+        name="my_test_flow",
+        flow_data=flow_data,
+        state="deployed",
+        variables=variables,
     )
     factory = FlowFactory()
     dag = factory.build(flow_panel)

     leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
     result = await leaf_node.call("Alice")
+    expected_dict = {
+        "openai_api_key": "my_openai_api_key",
+        "model": "GPT-4o",
+        "dag_var1": "value1",
+        "dag_var2": "value2",
+    }
+    expected_dict_str = json.dumps(expected_dict, ensure_ascii=False)
     assert (
         result
-        == "End operator received input: Your name is Alice, and your model info is "
-        '{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.'
+        == f"End operator received input: Your name is Alice, and your model info is "
+        f"{expected_dict_str}."
     )
@@ -20,6 +20,7 @@ from typing import (
 )

 from dbgpt.component import ComponentType, SystemApp
+from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
 from dbgpt.util.executor_utils import (
     AsyncToSyncIterator,
     BlockingFunction,
@@ -28,7 +29,7 @@ from dbgpt.util.executor_utils import (
 )
 from dbgpt.util.tracer import root_tracer

-from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
+from ..dag.base import DAG, DAGContext, DAGNode, DAGVar, DAGVariables
 from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data

 if TYPE_CHECKING:
@@ -58,6 +59,7 @@ class WorkflowRunner(ABC, Generic[T]):
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
         exist_dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> DAGContext:
         """Execute the workflow starting from a given operator.

@@ -67,6 +69,7 @@ class WorkflowRunner(ABC, Generic[T]):
             streaming_call (bool): Whether the call is a streaming call.
             exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
                 Defaults to None.
+            dag_variables (DAGVariables): The DAG variables.
         Returns:
             DAGContext: The context after executing the workflow, containing the final
                 state and data.
@@ -243,6 +246,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         self,
         call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> OUT:
         """Execute the node and return the output.

@@ -252,6 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             call_data (CALL_DATA): The data pass to root operator node.
             dag_ctx (DAGContext): The context of the DAG when this node is run,
                 Defaults to None.
+            dag_variables (DAGVariables): The DAG variables passed to current DAG.
         Returns:
             OUT: The output of the node after execution.
         """
@@ -259,13 +264,15 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             call_data = {"data": call_data}
         with root_tracer.start_span("dbgpt.awel.operator.call"):
             out_ctx = await self._runner.execute_workflow(
-                self, call_data, exist_dag_ctx=dag_ctx
+                self, call_data, exist_dag_ctx=dag_ctx, dag_variables=dag_variables
             )
             return out_ctx.current_task_context.task_output.output

     def _blocking_call(
         self,
         call_data: Optional[CALL_DATA] = EMPTY_DATA,
+        dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> OUT:
         """Execute the node and return the output.
@@ -275,7 +282,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Args:
             call_data (CALL_DATA): The data pass to root operator node.

+            dag_ctx (DAGContext): The context of the DAG when this node is run,
+                Defaults to None.
+            dag_variables (DAGVariables): The DAG variables passed to current DAG.
             loop (asyncio.BaseEventLoop): The event loop to run the operator.
         Returns:
             OUT: The output of the node after execution.
         """
@@ -284,12 +294,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         if not loop:
             loop = get_or_create_event_loop()
         loop = cast(asyncio.BaseEventLoop, loop)
-        return loop.run_until_complete(self.call(call_data))
+        return loop.run_until_complete(self.call(call_data, dag_ctx, dag_variables))

     async def call_stream(
         self,
         call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> AsyncIterator[OUT]:
         """Execute the node and return the output as a stream.

@@ -299,7 +310,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             call_data (CALL_DATA): The data pass to root operator node.
             dag_ctx (DAGContext): The context of the DAG when this node is run,
                 Defaults to None.
-
+            dag_variables (DAGVariables): The DAG variables passed to current DAG.
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
@@ -307,7 +318,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             call_data = {"data": call_data}
         with root_tracer.start_span("dbgpt.awel.operator.call_stream"):
             out_ctx = await self._runner.execute_workflow(
-                self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
+                self,
+                call_data,
+                streaming_call=True,
+                exist_dag_ctx=dag_ctx,
+                dag_variables=dag_variables,
             )

             task_output = out_ctx.current_task_context.task_output
@@ -328,6 +343,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def _blocking_call_stream(
         self,
         call_data: Optional[CALL_DATA] = EMPTY_DATA,
+        dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> Iterator[OUT]:
         """Execute the node and return the output as a stream.
@@ -337,7 +354,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Args:
             call_data (CALL_DATA): The data pass to root operator node.

+            dag_ctx (DAGContext): The context of the DAG when this node is run,
+                Defaults to None.
+            dag_variables (DAGVariables): The DAG variables passed to current DAG.
             loop (asyncio.BaseEventLoop): The event loop to run the operator.
         Returns:
             Iterator[OUT]: An iterator over the output stream.
         """
@@ -345,7 +365,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):

         if not loop:
             loop = get_or_create_event_loop()
-        return AsyncToSyncIterator(self.call_stream(call_data), loop)
+        return AsyncToSyncIterator(
+            self.call_stream(call_data, dag_ctx, dag_variables), loop
+        )

     async def blocking_func_to_async(
         self, func: BlockingFunction, *args, **kwargs
@@ -373,20 +395,77 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         """Check if the operator can be skipped in the branch."""
         return self._can_skip_in_branch

-    async def _resolve_variables(self, _: DAGContext):
-        from ...interface.variables import VariablesPlaceHolder
+    async def _resolve_variables(self, dag_ctx: DAGContext):
+        """Resolve variables in the operator.
+
+        Some attributes of the operator may be VariablesPlaceHolder, which need to be
+        resolved before the operator is executed.
+
+        Args:
+            dag_ctx (DAGContext): The context of the DAG when this node is run.
+        """
+        from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder

         if not self._variables_provider:
             return

+        if dag_ctx._dag_variables:
+            # Resolve variables in DAG context
+            resolve_tasks = []
+            resolve_items = []
+            for item in dag_ctx._dag_variables.items:
+                # TODO: Resolve variables just once?
+                if isinstance(item.value, VariablesPlaceHolder):
+                    resolve_tasks.append(
+                        self.blocking_func_to_async(
+                            item.value.parse, self._variables_provider
+                        )
+                    )
+                    resolve_items.append(item)
+            resolved_values = await asyncio.gather(*resolve_tasks)
+            for item, rv in zip(resolve_items, resolved_values):
+                item.value = rv
+        dag_provider: Optional["VariablesProvider"] = None
+        if dag_ctx._dag_variables:
+            dag_provider = dag_ctx._dag_variables.to_provider()
+
         # TODO: Resolve variables parallel
         for attr, value in self.__dict__.items():
+            # Handle all attributes that are VariablesPlaceHolder
             if isinstance(value, VariablesPlaceHolder):
-                resolved_value = await self.blocking_func_to_async(
-                    value.parse, self._variables_provider
-                )
-                logger.debug(
-                    f"Resolve variable {attr} with value {resolved_value} for {self}"
-                )
+                resolved_value: Any = None
+                default_identifier_map = None
+                id_key = VariablesIdentifier.from_str_identifier(value.full_key)
+                if (
+                    id_key.scope == VARIABLES_SCOPE_FLOW_PRIVATE
+                    and id_key.scope_key is None
+                    and self.dag
+                ):
+                    default_identifier_map = {"scope_key": self.dag.dag_id}
+
+                if dag_provider:
+                    # First try to resolve the variable with the DAG variables
+                    resolved_value = await self.blocking_func_to_async(
+                        value.parse,
+                        dag_provider,
+                        ignore_not_found_error=True,
+                        default_identifier_map=default_identifier_map,
+                    )
+                if resolved_value is None:
+                    resolved_value = await self.blocking_func_to_async(
+                        value.parse,
+                        self._variables_provider,
+                        default_identifier_map=default_identifier_map,
+                    )
+                    logger.debug(
+                        f"Resolve variable {attr} with value {resolved_value} for "
+                        f"{self} from system variables"
+                    )
+                else:
+                    logger.debug(
+                        f"Resolve variable {attr} with value {resolved_value} for "
+                        f"{self} from DAG variables"
+                    )
                 setattr(self, attr, resolved_value)
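The upshot of `_resolve_variables` above: each `VariablesPlaceHolder` attribute is tried against the DAG-scoped provider first (with `ignore_not_found_error=True`) and only falls back to the system-wide provider when that yields nothing; a `flow_priv` placeholder with no embedded scope_key is looked up under the current `dag_id`. Callers can also pass variables explicitly at any entry point; a sketch, assuming `task` is a built `BaseOperator` and `my_vars` is a `DAGVariables` instance (both names are ours):

```python
# Per-call DAG variables; merged with (and preferred over) the enclosing
# context's variables by the workflow runner.
out = await task.call("Alice", dag_variables=my_vars)

# The same parameter exists on the streaming entry point.
async for chunk in task.call_stream("Alice", dag_variables=my_vars):
    ...
```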
@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, cast
 from dbgpt.component import SystemApp
 from dbgpt.util.tracer import root_tracer

-from ..dag.base import DAGContext, DAGVar
+from ..dag.base import DAGContext, DAGVar, DAGVariables
 from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
 from ..operators.common_operator import BranchOperator
 from ..task.base import SKIP_DATA, TaskContext, TaskState
@@ -46,6 +46,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
         exist_dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> DAGContext:
         """Execute the workflow.

@@ -57,6 +58,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
                 Defaults to False.
             exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
                 Defaults to None.
+            dag_variables (Optional[DAGVariables], optional): The DAG variables.
         """
         # Save node output
         # dag = node.dag
@@ -71,12 +73,19 @@ class DefaultWorkflowRunner(WorkflowRunner):
             node_outputs = exist_dag_ctx._node_to_outputs
             share_data = exist_dag_ctx._share_data
             event_loop_task_id = exist_dag_ctx._event_loop_task_id
+            if dag_variables and exist_dag_ctx._dag_variables:
+                # Merge dag variables, prefer the `dag_variables` in the parameter
+                dag_variables = dag_variables.merge(exist_dag_ctx._dag_variables)
+        if node.dag and not dag_variables and node.dag._default_dag_variables:
+            # Use default dag variables if not set
+            dag_variables = node.dag._default_dag_variables
         dag_ctx = DAGContext(
             event_loop_task_id=event_loop_task_id,
             node_to_outputs=node_outputs,
             share_data=share_data,
             streaming_call=streaming_call,
             node_name_to_ids=job_manager._node_name_to_ids,
+            dag_variables=dag_variables,
         )
         # if node.dag:
         #     self._running_dag_ctx[node.dag.dag_id] = dag_ctx
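The precedence implemented above, restated as a hypothetical helper (not library code): explicit per-call variables win and are merged with the enclosing context's variables, while the DAG's `default_dag_variables` apply only when nothing was passed at all:

```python
def _effective_dag_variables(param_vars, exist_ctx_vars, dag_defaults):
    if param_vars and exist_ctx_vars:
        # merge() keeps the receiver's items on key collisions, so the
        # per-call variables are preferred.
        param_vars = param_vars.merge(exist_ctx_vars)
    return param_vars or dag_defaults
```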
@@ -212,11 +212,18 @@ class VariablesIdentifier(ResourceIdentifier):
         }

     @classmethod
-    def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier":
+    def from_str_identifier(
+        cls,
+        str_identifier: str,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ) -> "VariablesIdentifier":
         """Create a VariablesIdentifier from a string identifier.

         Args:
             str_identifier (str): The string identifier.
+            default_identifier_map (Optional[Dict[str, str]]): The default identifier
+                map, which contains the default values for the identifier. Defaults to
+                None.

         Returns:
             VariablesIdentifier: The VariablesIdentifier.
@@ -229,13 +236,20 @@ class VariablesIdentifier(ResourceIdentifier):
         if not variable_dict.get("name"):
             raise ValueError("Invalid string identifier, must have name")

+        def _get_value(key, default_value: Optional[str] = None) -> Optional[str]:
+            if variable_dict.get(key) is not None:
+                return variable_dict.get(key)
+            if default_identifier_map is not None and default_identifier_map.get(key):
+                return default_identifier_map.get(key)
+            return default_value
+
         return cls(
             key=variable_dict["key"],
             name=variable_dict["name"],
-            scope=variable_dict.get("scope", "global"),
-            scope_key=variable_dict.get("scope_key"),
-            sys_code=variable_dict.get("sys_code"),
-            user_name=variable_dict.get("user_name"),
+            scope=variable_dict["scope"],
+            scope_key=_get_value("scope_key"),
+            sys_code=_get_value("sys_code"),
+            user_name=_get_value("user_name"),
         )
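A hedged example of the new defaulting behavior (the dag id is hypothetical): values present in the parsed identifier always win; the `default_identifier_map` only fills fields the string identifier does not carry:

```python
id_key = VariablesIdentifier.from_str_identifier(
    "${dbgpt.core.flow.params:name1@flow_priv}",
    default_identifier_map={"scope_key": "flow_dag_my_test_flow_uid123"},
)
assert id_key.scope == "flow_priv"
# No scope_key in the placeholder, so the default map supplies it:
assert id_key.scope_key == "flow_dag_my_test_flow_uid123"
```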
@@ -256,6 +270,7 @@ class StorageVariables(StorageItem):
     encryption_method: Optional[str] = None
     salt: Optional[str] = None
     enabled: int = 1
+    description: Optional[str] = None

     _identifier: VariablesIdentifier = dataclasses.field(init=False)

@@ -297,6 +312,8 @@ class StorageVariables(StorageItem):
             "category": self.category,
             "encryption_method": self.encryption_method,
             "salt": self.salt,
+            "enabled": self.enabled,
+            "description": self.description,
         }

     def from_object(self, other: "StorageVariables") -> None:
@@ -311,6 +328,8 @@ class StorageVariables(StorageItem):
         self.user_name = other.user_name
         self.encryption_method = other.encryption_method
         self.salt = other.salt
+        self.enabled = other.enabled
+        self.description = other.description

     @classmethod
     def from_identifier(
@@ -347,7 +366,10 @@ class VariablesProvider(BaseComponent, ABC):

     @abstractmethod
     def get(
-        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
     ) -> Any:
         """Query variables from storage."""

@@ -416,9 +438,23 @@ class VariablesPlaceHolder:
         self.full_key = full_key
         self.default_value = default_value

-    def parse(self, variables_provider: VariablesProvider) -> Any:
+    def parse(
+        self,
+        variables_provider: VariablesProvider,
+        ignore_not_found_error: bool = False,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ):
         """Parse the variables."""
-        return variables_provider.get(self.full_key, self.default_value)
+        try:
+            return variables_provider.get(
+                self.full_key,
+                self.default_value,
+                default_identifier_map=default_identifier_map,
+            )
+        except ValueError as e:
+            if ignore_not_found_error:
+                return None
+            raise e

     def __repr__(self):
         """Return the representation of the variables place holder."""
@@ -449,10 +485,13 @@ class StorageVariablesProvider(VariablesProvider):
         self.system_app = system_app

     def get(
-        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
     ) -> Any:
         """Query variables from storage."""
-        key = VariablesIdentifier.from_str_identifier(full_key)
+        key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
         variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
         if variable is None:
             if default_value == _EMPTY_DEFAULT_VALUE:
@@ -641,7 +680,10 @@ class BuiltinVariablesProvider(VariablesProvider, ABC):
         self.system_app = system_app

     def get(
-        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
     ) -> Any:
         """Query variables from storage."""
         raise NotImplementedError("BuiltinVariablesProvider does not support get.")
@@ -49,6 +49,7 @@ class ServeEntity(Model):
     editable = Column(
         Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
     )
+    variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
     user_name = Column(String(128), index=True, nullable=True, comment="User name")
     sys_code = Column(String(128), index=True, nullable=True, comment="System code")
     gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -109,14 +110,14 @@ class VariablesEntity(Model):
         String(32),
         default="global",
         nullable=True,
-        comment="Variable scope(global,flow,app,agent,datasource,flow:uid,"
-        "flow:dag_name,agent:agent_name) etc",
+        comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
+        "etc)",
     )
     scope_key = Column(
         String(256),
         nullable=True,
-        comment="Variable scope key, default is empty, for scope is 'flow:uid', "
-        "the scope_key is uid of flow",
+        comment="Variable scope key, default is empty, for scope is 'flow_priv', "
+        "the scope_key is dag id of flow",
     )
     enabled = Column(
         Integer,
@@ -124,6 +125,7 @@ class VariablesEntity(Model):
         nullable=True,
         comment="Variable enabled, 0: disabled, 1: enabled",
     )
+    description = Column(Text, nullable=True, comment="Variable description")
     user_name = Column(String(128), index=True, nullable=True, comment="User name")
     sys_code = Column(String(128), index=True, nullable=True, comment="System code")
     gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -154,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
         error_message = request_dict.get("error_message")
         if error_message:
             error_message = error_message[:500]
+
+        variables_raw = request_dict.get("variables")
+        variables = (
+            json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
+        )
         new_dict = {
             "uid": request_dict.get("uid"),
             "dag_id": request_dict.get("dag_id"),
@@ -169,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             "define_type": request_dict.get("define_type"),
             "editable": ServeEntity.parse_editable(request_dict.get("editable")),
             "description": request_dict.get("description"),
+            "variables": variables,
             "user_name": request_dict.get("user_name"),
             "sys_code": request_dict.get("sys_code"),
         }
@@ -185,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             REQ: The request
         """
         flow_data = json.loads(entity.flow_data)
+        variables_raw = json.loads(entity.variables) if entity.variables else None
+        variables = ServeRequest.parse_variables(variables_raw)
         return ServeRequest(
             uid=entity.uid,
             dag_id=entity.dag_id,
@@ -200,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             define_type=entity.define_type,
             editable=ServeEntity.to_bool_editable(entity.editable),
             description=entity.description,
+            variables=variables,
             user_name=entity.user_name,
             sys_code=entity.sys_code,
         )
@@ -216,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
         flow_data = json.loads(entity.flow_data)
         gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
         gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
+        variables_raw = json.loads(entity.variables) if entity.variables else None
+        variables = ServeRequest.parse_variables(variables_raw)
         return ServerResponse(
             uid=entity.uid,
             dag_id=entity.dag_id,
@@ -231,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             version=entity.version,
             editable=ServeEntity.to_bool_editable(entity.editable),
             define_type=entity.define_type,
+            variables=variables,
             user_name=entity.user_name,
             sys_code=entity.sys_code,
             gmt_created=gmt_created_str,
@@ -271,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             entry.editable = ServeEntity.parse_editable(update_request.editable)
         if update_request.define_type:
             entry.define_type = update_request.define_type
+
+        if update_request.variables:
+            variables_raw = update_request.get_variables_dict()
+            entry.variables = (
+                json.dumps(variables_raw, ensure_ascii=False)
+                if variables_raw
+                else None
+            )
         if update_request.user_name:
             entry.user_name = update_request.user_name
         if update_request.sys_code:
@@ -317,6 +339,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
             "enabled": enabled,
             "user_name": request_dict.get("user_name"),
             "sys_code": request_dict.get("sys_code"),
+            "description": request_dict.get("description"),
         }
         entity = VariablesEntity(**new_dict)
         return entity
@@ -348,6 +371,7 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
             enabled=enabled,
             user_name=entity.user_name,
             sys_code=entity.sys_code,
+            description=entity.description,
         )

     def to_response(self, entity: VariablesEntity) -> VariablesResponse:
@@ -382,4 +406,5 @@ class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]
             sys_code=entity.sys_code,
             gmt_created=gmt_created_str,
             gmt_modified=gmt_modified_str,
+            description=entity.description,
         )
@@ -29,6 +29,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
             scope_key=item.scope_key,
             sys_code=item.sys_code,
             user_name=item.user_name,
+            description=item.description,
         )

     def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
@@ -46,6 +47,7 @@ class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
             scope_key=model.scope_key,
             sys_code=model.sys_code,
             user_name=model.user_name,
+            description=model.description,
         )

     def get_query_for_identifier(
@@ -90,6 +90,8 @@ class VariablesService(
             scope_key=request.scope_key,
             user_name=request.user_name,
             sys_code=request.sys_code,
+            enabled=1 if request.enabled else 0,
+            description=request.description,
         )
         self.variables_provider.save(variables)
         query = {
@@ -123,6 +125,8 @@ class VariablesService(
             scope_key=request.scope_key,
             user_name=request.user_name,
             sys_code=request.sys_code,
+            enabled=1 if request.enabled else 0,
+            description=request.description,
         )
         exist_value = self.variables_provider.get(
             variables.identifier.str_identifier, None