fix: Fix AWEL flow load variables bug

This commit is contained in:
Fangyin Cheng
2024-09-06 16:03:33 +08:00
parent 92c9695559
commit f496bf3ac9
5 changed files with 116 additions and 32 deletions

View File

@@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
"""
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
from ...interface.variables import (
VariablesIdentifier,
VariablesPlaceHolder,
is_variable_string,
)
if not self._variables_provider:
return
@@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
resolve_items = []
for item in dag_ctx._dag_variables.items:
# TODO: Resolve variables just once?
if not item.value:
continue
if isinstance(item.value, str) and is_variable_string(item.value):
item.value = VariablesPlaceHolder(item.name, item.value)
if isinstance(item.value, VariablesPlaceHolder):
resolve_tasks.append(
self.blocking_func_to_async(
item.value.parse, self._variables_provider
)
item.value.async_parse(self._variables_provider)
)
resolve_items.append(item)
resolved_values = await asyncio.gather(*resolve_tasks)
@@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if dag_provider:
# First try to resolve the variable with the DAG variables
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
dag_provider,
ignore_not_found_error=True,
default_identifier_map=default_identifier_map,
)
if resolved_value is None:
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
self._variables_provider,
default_identifier_map=default_identifier_map,
)

View File

@@ -4,12 +4,14 @@ from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from dbgpt.component import SystemApp
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator
@@ -104,7 +106,9 @@ async def stream_input_nodes(request):
@asynccontextmanager
async def _create_variables(**kwargs):
vp = StorageVariablesProvider()
sys_app = SystemApp()
DAGVar.set_current_system_app(sys_app)
vp = StorageVariablesProvider(system_app=sys_app)
vars = kwargs.get("vars")
if vars and isinstance(vars, dict):
for param_key, param_var in vars.items():