fix: Fix AWEL flow load variables bug

This commit is contained in:
Fangyin Cheng
2024-09-06 16:03:33 +08:00
parent 92c9695559
commit f496bf3ac9
5 changed files with 116 additions and 32 deletions

View File

@@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
"""
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
from ...interface.variables import (
VariablesIdentifier,
VariablesPlaceHolder,
is_variable_string,
)
if not self._variables_provider:
return
@@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
resolve_items = []
for item in dag_ctx._dag_variables.items:
# TODO: Resolve variables just once?
if not item.value:
continue
if isinstance(item.value, str) and is_variable_string(item.value):
item.value = VariablesPlaceHolder(item.name, item.value)
if isinstance(item.value, VariablesPlaceHolder):
resolve_tasks.append(
self.blocking_func_to_async(
item.value.parse, self._variables_provider
)
item.value.async_parse(self._variables_provider)
)
resolve_items.append(item)
resolved_values = await asyncio.gather(*resolve_tasks)
@@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if dag_provider:
# First try to resolve the variable with the DAG variables
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
dag_provider,
ignore_not_found_error=True,
default_identifier_map=default_identifier_map,
)
if resolved_value is None:
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
self._variables_provider,
default_identifier_map=default_identifier_map,
)

View File

@@ -4,12 +4,14 @@ from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from dbgpt.component import SystemApp
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator
@@ -104,7 +106,9 @@ async def stream_input_nodes(request):
@asynccontextmanager
async def _create_variables(**kwargs):
vp = StorageVariablesProvider()
sys_app = SystemApp()
DAGVar.set_current_system_app(sys_app)
vp = StorageVariablesProvider(system_app=sys_app)
vars = kwargs.get("vars")
if vars and isinstance(vars, dict):
for param_key, param_var in vars.items():

View File

@@ -374,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC):
) -> Any:
"""Query variables from storage."""
async def async_get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage async."""
raise NotImplementedError("Current variables provider does not support async.")
@abstractmethod
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
@@ -457,6 +466,24 @@ class VariablesPlaceHolder:
return None
raise e
async def async_parse(
self,
variables_provider: VariablesProvider,
ignore_not_found_error: bool = False,
default_identifier_map: Optional[Dict[str, str]] = None,
):
"""Parse the variables async."""
try:
return await variables_provider.async_get(
self.full_key,
self.default_value,
default_identifier_map=default_identifier_map,
)
except ValueError as e:
if ignore_not_found_error:
return None
raise e
def __repr__(self):
"""Return the representation of the variables place holder."""
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
@@ -508,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider):
variable.value = self.encryption.decrypt(variable.value, variable.salt)
return self._convert_to_value_type(variable)
async def async_get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage async."""
# Try to get variables from storage
value = await blocking_func_to_async_no_executor(
self.get,
full_key,
default_value=None,
default_identifier_map=default_identifier_map,
)
if value is not None:
return value
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
# Get all builtin variables
variables = await self.async_get_variables(
key=key.key,
scope=key.scope,
scope_key=key.scope_key,
sys_code=key.sys_code,
user_name=key.user_name,
)
values = [v for v in variables if v.name == key.name]
if not values:
if default_value == _EMPTY_DEFAULT_VALUE:
raise ValueError(f"Variable {full_key} not found")
return default_value
if len(values) > 1:
raise ValueError(f"Multiple variables found for {full_key}")
variable = values[0]
return self._convert_to_value_type(variable)
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
if variables_item.category == "secret":
@@ -577,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider):
)
if is_builtin:
return builtin_variables
executor_factory: Optional[
DefaultExecutorFactory
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
executor_factory: Optional[DefaultExecutorFactory] = None
if self.system_app:
executor_factory = DefaultExecutorFactory.get_instance(
self.system_app, default_component=None
)
if executor_factory:
return await blocking_func_to_async(
executor_factory.create(),