mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
fix: Fix AWEL flow load variables bug
This commit is contained in:
parent
92c9695559
commit
f496bf3ac9
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core import (
|
||||
@ -91,7 +91,7 @@ class BaseHOLLMOperator(
|
||||
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
|
||||
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
|
||||
self._max_token_limit = max_token_limit
|
||||
self._sub_compose_dag = self._build_conversation_composer_dag()
|
||||
self._sub_compose_dag: Optional[DAG] = None
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
|
||||
conv_serve = ConversationServe.get_instance(self.system_app)
|
||||
@ -166,7 +166,7 @@ class BaseHOLLMOperator(
|
||||
"messages": history_messages,
|
||||
"prompt_dict": prompt_dict,
|
||||
}
|
||||
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
|
||||
end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0])
|
||||
# Sub dag, use the same dag context in the parent dag
|
||||
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
|
||||
model_request = ModelRequest.build_request(
|
||||
@ -184,6 +184,12 @@ class BaseHOLLMOperator(
|
||||
storage_conv.add_user_message(user_input)
|
||||
return model_request
|
||||
|
||||
@property
|
||||
def sub_compose_dag(self) -> DAG:
|
||||
if not self._sub_compose_dag:
|
||||
self._sub_compose_dag = self._build_conversation_composer_dag()
|
||||
return self._sub_compose_dag
|
||||
|
||||
def _build_storage(
|
||||
self, req: CommonLLMHttpRequestBody
|
||||
) -> Tuple[StorageConversation, List[BaseMessage]]:
|
||||
@ -207,7 +213,11 @@ class BaseHOLLMOperator(
|
||||
return storage_conv, history_messages
|
||||
|
||||
def _build_conversation_composer_dag(self) -> DAG:
|
||||
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
|
||||
default_dag_variables = self.dag._default_dag_variables if self.dag else None
|
||||
with DAG(
|
||||
"dbgpt_awel_app_chat_history_prompt_composer",
|
||||
default_dag_variables=default_dag_variables,
|
||||
) as composer_dag:
|
||||
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||
# History transform task
|
||||
if self._history_merge_mode == "token":
|
||||
|
@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
Args:
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run.
|
||||
"""
|
||||
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
|
||||
from ...interface.variables import (
|
||||
VariablesIdentifier,
|
||||
VariablesPlaceHolder,
|
||||
is_variable_string,
|
||||
)
|
||||
|
||||
if not self._variables_provider:
|
||||
return
|
||||
@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
resolve_items = []
|
||||
for item in dag_ctx._dag_variables.items:
|
||||
# TODO: Resolve variables just once?
|
||||
if not item.value:
|
||||
continue
|
||||
if isinstance(item.value, str) and is_variable_string(item.value):
|
||||
item.value = VariablesPlaceHolder(item.name, item.value)
|
||||
if isinstance(item.value, VariablesPlaceHolder):
|
||||
resolve_tasks.append(
|
||||
self.blocking_func_to_async(
|
||||
item.value.parse, self._variables_provider
|
||||
)
|
||||
item.value.async_parse(self._variables_provider)
|
||||
)
|
||||
resolve_items.append(item)
|
||||
resolved_values = await asyncio.gather(*resolve_tasks)
|
||||
@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
if dag_provider:
|
||||
# First try to resolve the variable with the DAG variables
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
value.parse,
|
||||
resolved_value = await value.async_parse(
|
||||
dag_provider,
|
||||
ignore_not_found_error=True,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
if resolved_value is None:
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
value.parse,
|
||||
resolved_value = await value.async_parse(
|
||||
self._variables_provider,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
|
@ -4,12 +4,14 @@ from typing import AsyncIterator, List
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from ...interface.variables import (
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
)
|
||||
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
||||
@ -104,7 +106,9 @@ async def stream_input_nodes(request):
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_variables(**kwargs):
|
||||
vp = StorageVariablesProvider()
|
||||
sys_app = SystemApp()
|
||||
DAGVar.set_current_system_app(sys_app)
|
||||
vp = StorageVariablesProvider(system_app=sys_app)
|
||||
vars = kwargs.get("vars")
|
||||
if vars and isinstance(vars, dict):
|
||||
for param_key, param_var in vars.items():
|
||||
|
@ -374,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC):
|
||||
) -> Any:
|
||||
"""Query variables from storage."""
|
||||
|
||||
async def async_get(
|
||||
self,
|
||||
full_key: str,
|
||||
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
|
||||
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""Query variables from storage async."""
|
||||
raise NotImplementedError("Current variables provider does not support async.")
|
||||
|
||||
@abstractmethod
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
@ -457,6 +466,24 @@ class VariablesPlaceHolder:
|
||||
return None
|
||||
raise e
|
||||
|
||||
async def async_parse(
|
||||
self,
|
||||
variables_provider: VariablesProvider,
|
||||
ignore_not_found_error: bool = False,
|
||||
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Parse the variables async."""
|
||||
try:
|
||||
return await variables_provider.async_get(
|
||||
self.full_key,
|
||||
self.default_value,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
except ValueError as e:
|
||||
if ignore_not_found_error:
|
||||
return None
|
||||
raise e
|
||||
|
||||
def __repr__(self):
|
||||
"""Return the representation of the variables place holder."""
|
||||
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
|
||||
@ -508,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider):
|
||||
variable.value = self.encryption.decrypt(variable.value, variable.salt)
|
||||
return self._convert_to_value_type(variable)
|
||||
|
||||
async def async_get(
|
||||
self,
|
||||
full_key: str,
|
||||
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
|
||||
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""Query variables from storage async."""
|
||||
# Try to get variables from storage
|
||||
value = await blocking_func_to_async_no_executor(
|
||||
self.get,
|
||||
full_key,
|
||||
default_value=None,
|
||||
default_identifier_map=default_identifier_map,
|
||||
)
|
||||
if value is not None:
|
||||
return value
|
||||
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
|
||||
# Get all builtin variables
|
||||
variables = await self.async_get_variables(
|
||||
key=key.key,
|
||||
scope=key.scope,
|
||||
scope_key=key.scope_key,
|
||||
sys_code=key.sys_code,
|
||||
user_name=key.user_name,
|
||||
)
|
||||
values = [v for v in variables if v.name == key.name]
|
||||
if not values:
|
||||
if default_value == _EMPTY_DEFAULT_VALUE:
|
||||
raise ValueError(f"Variable {full_key} not found")
|
||||
return default_value
|
||||
if len(values) > 1:
|
||||
raise ValueError(f"Multiple variables found for {full_key}")
|
||||
|
||||
variable = values[0]
|
||||
return self._convert_to_value_type(variable)
|
||||
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
if variables_item.category == "secret":
|
||||
@ -577,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider):
|
||||
)
|
||||
if is_builtin:
|
||||
return builtin_variables
|
||||
executor_factory: Optional[
|
||||
DefaultExecutorFactory
|
||||
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
|
||||
executor_factory: Optional[DefaultExecutorFactory] = None
|
||||
if self.system_app:
|
||||
executor_factory = DefaultExecutorFactory.get_instance(
|
||||
self.system_app, default_component=None
|
||||
)
|
||||
if executor_factory:
|
||||
return await blocking_func_to_async(
|
||||
executor_factory.create(),
|
||||
|
@ -4,7 +4,11 @@ from typing import List, Optional, Union
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.interface.variables import VariablesProvider
|
||||
from dbgpt.core.interface.variables import (
|
||||
FernetEncryption,
|
||||
StorageVariablesProvider,
|
||||
VariablesProvider,
|
||||
)
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
@ -33,6 +37,7 @@ class Serve(BaseServe):
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
|
||||
if api_prefix is None:
|
||||
api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"]
|
||||
if api_tags is None:
|
||||
@ -41,8 +46,15 @@ class Serve(BaseServe):
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
self._variables_provider: Optional[VariablesProvider] = None
|
||||
self._serve_config: Optional[ServeConfig] = None
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._variables_provider: StorageVariablesProvider = StorageVariablesProvider(
|
||||
storage=None,
|
||||
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
||||
system_app=system_app,
|
||||
)
|
||||
system_app.register_instance(self._variables_provider)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
@ -65,10 +77,6 @@ class Serve(BaseServe):
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application."""
|
||||
from dbgpt.core.interface.variables import (
|
||||
FernetEncryption,
|
||||
StorageVariablesProvider,
|
||||
)
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
@ -76,9 +84,6 @@ class Serve(BaseServe):
|
||||
from .models.variables_adapter import VariablesAdapter
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
storage_adapter = VariablesAdapter()
|
||||
@ -89,11 +94,7 @@ class Serve(BaseServe):
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
self._variables_provider = StorageVariablesProvider(
|
||||
storage=storage,
|
||||
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
||||
system_app=self._system_app,
|
||||
)
|
||||
self._variables_provider.storage = storage
|
||||
|
||||
@property
|
||||
def variables_provider(self):
|
||||
|
Loading…
Reference in New Issue
Block a user