From f496bf3ac939a3c5f4e1c0f590ab3f4ccf41652c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 6 Sep 2024 16:03:33 +0800 Subject: [PATCH] fix: Fix AWEL flow load variables bug --- dbgpt/app/operators/llm.py | 18 ++++++-- dbgpt/core/awel/operators/base.py | 20 +++++---- dbgpt/core/awel/tests/conftest.py | 8 +++- dbgpt/core/interface/variables.py | 71 +++++++++++++++++++++++++++++-- dbgpt/serve/flow/serve.py | 31 +++++++------- 5 files changed, 116 insertions(+), 32 deletions(-) diff --git a/dbgpt/app/operators/llm.py b/dbgpt/app/operators/llm.py index 56b67a010..7ba44cb32 100644 --- a/dbgpt/app/operators/llm.py +++ b/dbgpt/app/operators/llm.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union, cast from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core import ( @@ -91,7 +91,7 @@ class BaseHOLLMOperator( self._keep_start_rounds = keep_start_rounds if self._has_history else 0 self._keep_end_rounds = keep_end_rounds if self._has_history else 0 self._max_token_limit = max_token_limit - self._sub_compose_dag = self._build_conversation_composer_dag() + self._sub_compose_dag: Optional[DAG] = None async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]: conv_serve = ConversationServe.get_instance(self.system_app) @@ -166,7 +166,7 @@ class BaseHOLLMOperator( "messages": history_messages, "prompt_dict": prompt_dict, } - end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0]) # Sub dag, use the same dag context in the parent dag messages = await end_node.call(call_data, dag_ctx=self.current_dag_context) model_request = ModelRequest.build_request( @@ -184,6 +184,12 @@ class BaseHOLLMOperator( storage_conv.add_user_message(user_input) return model_request + @property + def sub_compose_dag(self) -> DAG: + if not self._sub_compose_dag: + self._sub_compose_dag = self._build_conversation_composer_dag() + return self._sub_compose_dag + def _build_storage( self, req: CommonLLMHttpRequestBody ) -> Tuple[StorageConversation, List[BaseMessage]]: @@ -207,7 +213,11 @@ class BaseHOLLMOperator( return storage_conv, history_messages def _build_conversation_composer_dag(self) -> DAG: - with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: + default_dag_variables = self.dag._default_dag_variables if self.dag else None + with DAG( + "dbgpt_awel_app_chat_history_prompt_composer", + default_dag_variables=default_dag_variables, + ) as composer_dag: input_task = InputOperator(input_source=SimpleCallDataInputSource()) # History transform task if self._history_merge_mode == "token": diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index 7c66c0adc..aa4940f35 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): Args: dag_ctx (DAGContext): The context of the DAG when this node is run. """ - from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder + from ...interface.variables import ( + VariablesIdentifier, + VariablesPlaceHolder, + is_variable_string, + ) if not self._variables_provider: return @@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): resolve_items = [] for item in dag_ctx._dag_variables.items: # TODO: Resolve variables just once? + if not item.value: + continue + if isinstance(item.value, str) and is_variable_string(item.value): + item.value = VariablesPlaceHolder(item.name, item.value) if isinstance(item.value, VariablesPlaceHolder): resolve_tasks.append( - self.blocking_func_to_async( - item.value.parse, self._variables_provider - ) + item.value.async_parse(self._variables_provider) ) resolve_items.append(item) resolved_values = await asyncio.gather(*resolve_tasks) @@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if dag_provider: # First try to resolve the variable with the DAG variables - resolved_value = await self.blocking_func_to_async( - value.parse, + resolved_value = await value.async_parse( dag_provider, ignore_not_found_error=True, default_identifier_map=default_identifier_map, ) if resolved_value is None: - resolved_value = await self.blocking_func_to_async( - value.parse, + resolved_value = await value.async_parse( self._variables_provider, default_identifier_map=default_identifier_map, ) diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index 607783028..2341b6602 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -4,12 +4,14 @@ from typing import AsyncIterator, List import pytest import pytest_asyncio +from dbgpt.component import SystemApp + from ...interface.variables import ( StorageVariables, StorageVariablesProvider, VariablesIdentifier, ) -from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource +from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource from ..task.task_impl import _is_async_iterator @@ -104,7 +106,9 @@ async def stream_input_nodes(request): @asynccontextmanager async def _create_variables(**kwargs): - vp = StorageVariablesProvider() + sys_app = SystemApp() + DAGVar.set_current_system_app(sys_app) + vp = StorageVariablesProvider(system_app=sys_app) vars = kwargs.get("vars") if vars and isinstance(vars, dict): for param_key, param_var in vars.items(): diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 5d538ad25..b932298e8 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -374,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC): ) -> Any: """Query variables from storage.""" + async def async_get( + self, + full_key: str, + default_value: Optional[str] = _EMPTY_DEFAULT_VALUE, + default_identifier_map: Optional[Dict[str, str]] = None, + ) -> Any: + """Query variables from storage async.""" + raise NotImplementedError("Current variables provider does not support async.") + @abstractmethod def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" @@ -457,6 +466,24 @@ class VariablesPlaceHolder: return None raise e + async def async_parse( + self, + variables_provider: VariablesProvider, + ignore_not_found_error: bool = False, + default_identifier_map: Optional[Dict[str, str]] = None, + ): + """Parse the variables async.""" + try: + return await variables_provider.async_get( + self.full_key, + self.default_value, + default_identifier_map=default_identifier_map, + ) + except ValueError as e: + if ignore_not_found_error: + return None + raise e + def __repr__(self): """Return the representation of the variables place holder.""" return f"" @@ -508,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider): variable.value = self.encryption.decrypt(variable.value, variable.salt) return self._convert_to_value_type(variable) + async def async_get( + self, + full_key: str, + default_value: Optional[str] = _EMPTY_DEFAULT_VALUE, + default_identifier_map: Optional[Dict[str, str]] = None, + ) -> Any: + """Query variables from storage async.""" + # Try to get variables from storage + value = await blocking_func_to_async_no_executor( + self.get, + full_key, + default_value=None, + default_identifier_map=default_identifier_map, + ) + if value is not None: + return value + key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map) + # Get all builtin variables + variables = await self.async_get_variables( + key=key.key, + scope=key.scope, + scope_key=key.scope_key, + sys_code=key.sys_code, + user_name=key.user_name, + ) + values = [v for v in variables if v.name == key.name] + if not values: + if default_value == _EMPTY_DEFAULT_VALUE: + raise ValueError(f"Variable {full_key} not found") + return default_value + if len(values) > 1: + raise ValueError(f"Multiple variables found for {full_key}") + + variable = values[0] + return self._convert_to_value_type(variable) + def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" if variables_item.category == "secret": @@ -577,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider): ) if is_builtin: return builtin_variables - executor_factory: Optional[ - DefaultExecutorFactory - ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None) + executor_factory: Optional[DefaultExecutorFactory] = None + if self.system_app: + executor_factory = DefaultExecutorFactory.get_instance( + self.system_app, default_component=None + ) if executor_factory: return await blocking_func_to_async( executor_factory.create(), diff --git a/dbgpt/serve/flow/serve.py b/dbgpt/serve/flow/serve.py index a27e3d28f..a8d0161f9 100644 --- a/dbgpt/serve/flow/serve.py +++ b/dbgpt/serve/flow/serve.py @@ -4,7 +4,11 @@ from typing import List, Optional, Union from sqlalchemy import URL from dbgpt.component import SystemApp -from dbgpt.core.interface.variables import VariablesProvider +from dbgpt.core.interface.variables import ( + FernetEncryption, + StorageVariablesProvider, + VariablesProvider, +) from dbgpt.serve.core import BaseServe from dbgpt.storage.metadata import DatabaseManager @@ -33,6 +37,7 @@ class Serve(BaseServe): db_url_or_db: Union[str, URL, DatabaseManager] = None, try_create_tables: Optional[bool] = False, ): + if api_prefix is None: api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"] if api_tags is None: @@ -41,8 +46,15 @@ class Serve(BaseServe): system_app, api_prefix, api_tags, db_url_or_db, try_create_tables ) self._db_manager: Optional[DatabaseManager] = None - self._variables_provider: Optional[VariablesProvider] = None - self._serve_config: Optional[ServeConfig] = None + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._variables_provider: StorageVariablesProvider = StorageVariablesProvider( + storage=None, + encryption=FernetEncryption(self._serve_config.encrypt_key), + system_app=system_app, + ) + system_app.register_instance(self._variables_provider) def init_app(self, system_app: SystemApp): if self._app_has_initiated: @@ -65,10 +77,6 @@ class Serve(BaseServe): def before_start(self): """Called before the start of the application.""" - from dbgpt.core.interface.variables import ( - FernetEncryption, - StorageVariablesProvider, - ) from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage from dbgpt.util.serialization.json_serialization import JsonSerializer @@ -76,9 +84,6 @@ class Serve(BaseServe): from .models.variables_adapter import VariablesAdapter self._db_manager = self.create_or_get_db_manager() - self._serve_config = ServeConfig.from_app_config( - self._system_app.config, SERVE_CONFIG_KEY_PREFIX - ) self._db_manager = self.create_or_get_db_manager() storage_adapter = VariablesAdapter() @@ -89,11 +94,7 @@ class Serve(BaseServe): storage_adapter, serializer, ) - self._variables_provider = StorageVariablesProvider( - storage=storage, - encryption=FernetEncryption(self._serve_config.encrypt_key), - system_app=self._system_app, - ) + self._variables_provider.storage = storage @property def variables_provider(self):