mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 13:00:02 +00:00
fix: Fix AWEL flow load variables bug
This commit is contained in:
@@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
Args:
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run.
|
||||
"""
|
||||
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
|
||||
from ...interface.variables import (
|
||||
VariablesIdentifier,
|
||||
VariablesPlaceHolder,
|
||||
is_variable_string,
|
||||
)
|
||||
|
||||
if not self._variables_provider:
|
||||
return
|
||||
@@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
resolve_items = []
|
||||
for item in dag_ctx._dag_variables.items:
|
||||
# TODO: Resolve variables just once?
|
||||
if not item.value:
|
||||
continue
|
||||
if isinstance(item.value, str) and is_variable_string(item.value):
|
||||
item.value = VariablesPlaceHolder(item.name, item.value)
|
||||
if isinstance(item.value, VariablesPlaceHolder):
|
||||
resolve_tasks.append(
|
||||
self.blocking_func_to_async(
|
||||
item.value.parse, self._variables_provider
|
||||
)
|
||||
item.value.async_parse(self._variables_provider)
|
||||
)
|
||||
resolve_items.append(item)
|
||||
resolved_values = await asyncio.gather(*resolve_tasks)
|
||||
@@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
if dag_provider:
|
||||
# First try to resolve the variable with the DAG variables
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
value.parse,
|
||||
resolved_value = await value.async_parse(
|
||||
dag_provider,
|
||||
ignore_not_found_error=True,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
if resolved_value is None:
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
value.parse,
|
||||
resolved_value = await value.async_parse(
|
||||
self._variables_provider,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
|
@@ -4,12 +4,14 @@ from typing import AsyncIterator, List
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from ...interface.variables import (
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
)
|
||||
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
||||
@@ -104,7 +106,9 @@ async def stream_input_nodes(request):
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_variables(**kwargs):
|
||||
vp = StorageVariablesProvider()
|
||||
sys_app = SystemApp()
|
||||
DAGVar.set_current_system_app(sys_app)
|
||||
vp = StorageVariablesProvider(system_app=sys_app)
|
||||
vars = kwargs.get("vars")
|
||||
if vars and isinstance(vars, dict):
|
||||
for param_key, param_var in vars.items():
|
||||
|
Reference in New Issue
Block a user