mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-04 10:00:17 +00:00
fix: Fix AWEL flow load variables bug
This commit is contained in:
parent
92c9695559
commit
f496bf3ac9
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal, Optional, Tuple, Union
|
from typing import List, Literal, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, Field
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
from dbgpt.core import (
|
from dbgpt.core import (
|
||||||
@ -91,7 +91,7 @@ class BaseHOLLMOperator(
|
|||||||
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
|
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
|
||||||
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
|
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
|
||||||
self._max_token_limit = max_token_limit
|
self._max_token_limit = max_token_limit
|
||||||
self._sub_compose_dag = self._build_conversation_composer_dag()
|
self._sub_compose_dag: Optional[DAG] = None
|
||||||
|
|
||||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
|
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
|
||||||
conv_serve = ConversationServe.get_instance(self.system_app)
|
conv_serve = ConversationServe.get_instance(self.system_app)
|
||||||
@ -166,7 +166,7 @@ class BaseHOLLMOperator(
|
|||||||
"messages": history_messages,
|
"messages": history_messages,
|
||||||
"prompt_dict": prompt_dict,
|
"prompt_dict": prompt_dict,
|
||||||
}
|
}
|
||||||
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
|
end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0])
|
||||||
# Sub dag, use the same dag context in the parent dag
|
# Sub dag, use the same dag context in the parent dag
|
||||||
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
|
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
|
||||||
model_request = ModelRequest.build_request(
|
model_request = ModelRequest.build_request(
|
||||||
@ -184,6 +184,12 @@ class BaseHOLLMOperator(
|
|||||||
storage_conv.add_user_message(user_input)
|
storage_conv.add_user_message(user_input)
|
||||||
return model_request
|
return model_request
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sub_compose_dag(self) -> DAG:
|
||||||
|
if not self._sub_compose_dag:
|
||||||
|
self._sub_compose_dag = self._build_conversation_composer_dag()
|
||||||
|
return self._sub_compose_dag
|
||||||
|
|
||||||
def _build_storage(
|
def _build_storage(
|
||||||
self, req: CommonLLMHttpRequestBody
|
self, req: CommonLLMHttpRequestBody
|
||||||
) -> Tuple[StorageConversation, List[BaseMessage]]:
|
) -> Tuple[StorageConversation, List[BaseMessage]]:
|
||||||
@ -207,7 +213,11 @@ class BaseHOLLMOperator(
|
|||||||
return storage_conv, history_messages
|
return storage_conv, history_messages
|
||||||
|
|
||||||
def _build_conversation_composer_dag(self) -> DAG:
|
def _build_conversation_composer_dag(self) -> DAG:
|
||||||
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
|
default_dag_variables = self.dag._default_dag_variables if self.dag else None
|
||||||
|
with DAG(
|
||||||
|
"dbgpt_awel_app_chat_history_prompt_composer",
|
||||||
|
default_dag_variables=default_dag_variables,
|
||||||
|
) as composer_dag:
|
||||||
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||||
# History transform task
|
# History transform task
|
||||||
if self._history_merge_mode == "token":
|
if self._history_merge_mode == "token":
|
||||||
|
@ -421,7 +421,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
|||||||
Args:
|
Args:
|
||||||
dag_ctx (DAGContext): The context of the DAG when this node is run.
|
dag_ctx (DAGContext): The context of the DAG when this node is run.
|
||||||
"""
|
"""
|
||||||
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
|
from ...interface.variables import (
|
||||||
|
VariablesIdentifier,
|
||||||
|
VariablesPlaceHolder,
|
||||||
|
is_variable_string,
|
||||||
|
)
|
||||||
|
|
||||||
if not self._variables_provider:
|
if not self._variables_provider:
|
||||||
return
|
return
|
||||||
@ -432,11 +436,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
|||||||
resolve_items = []
|
resolve_items = []
|
||||||
for item in dag_ctx._dag_variables.items:
|
for item in dag_ctx._dag_variables.items:
|
||||||
# TODO: Resolve variables just once?
|
# TODO: Resolve variables just once?
|
||||||
|
if not item.value:
|
||||||
|
continue
|
||||||
|
if isinstance(item.value, str) and is_variable_string(item.value):
|
||||||
|
item.value = VariablesPlaceHolder(item.name, item.value)
|
||||||
if isinstance(item.value, VariablesPlaceHolder):
|
if isinstance(item.value, VariablesPlaceHolder):
|
||||||
resolve_tasks.append(
|
resolve_tasks.append(
|
||||||
self.blocking_func_to_async(
|
item.value.async_parse(self._variables_provider)
|
||||||
item.value.parse, self._variables_provider
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
resolve_items.append(item)
|
resolve_items.append(item)
|
||||||
resolved_values = await asyncio.gather(*resolve_tasks)
|
resolved_values = await asyncio.gather(*resolve_tasks)
|
||||||
@ -462,15 +468,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
|||||||
|
|
||||||
if dag_provider:
|
if dag_provider:
|
||||||
# First try to resolve the variable with the DAG variables
|
# First try to resolve the variable with the DAG variables
|
||||||
resolved_value = await self.blocking_func_to_async(
|
resolved_value = await value.async_parse(
|
||||||
value.parse,
|
|
||||||
dag_provider,
|
dag_provider,
|
||||||
ignore_not_found_error=True,
|
ignore_not_found_error=True,
|
||||||
default_identifier_map=default_identifier_map,
|
default_identifier_map=default_identifier_map,
|
||||||
)
|
)
|
||||||
if resolved_value is None:
|
if resolved_value is None:
|
||||||
resolved_value = await self.blocking_func_to_async(
|
resolved_value = await value.async_parse(
|
||||||
value.parse,
|
|
||||||
self._variables_provider,
|
self._variables_provider,
|
||||||
default_identifier_map=default_identifier_map,
|
default_identifier_map=default_identifier_map,
|
||||||
)
|
)
|
||||||
|
@ -4,12 +4,14 @@ from typing import AsyncIterator, List
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from dbgpt.component import SystemApp
|
||||||
|
|
||||||
from ...interface.variables import (
|
from ...interface.variables import (
|
||||||
StorageVariables,
|
StorageVariables,
|
||||||
StorageVariablesProvider,
|
StorageVariablesProvider,
|
||||||
VariablesIdentifier,
|
VariablesIdentifier,
|
||||||
)
|
)
|
||||||
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||||
from ..task.task_impl import _is_async_iterator
|
from ..task.task_impl import _is_async_iterator
|
||||||
|
|
||||||
|
|
||||||
@ -104,7 +106,9 @@ async def stream_input_nodes(request):
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _create_variables(**kwargs):
|
async def _create_variables(**kwargs):
|
||||||
vp = StorageVariablesProvider()
|
sys_app = SystemApp()
|
||||||
|
DAGVar.set_current_system_app(sys_app)
|
||||||
|
vp = StorageVariablesProvider(system_app=sys_app)
|
||||||
vars = kwargs.get("vars")
|
vars = kwargs.get("vars")
|
||||||
if vars and isinstance(vars, dict):
|
if vars and isinstance(vars, dict):
|
||||||
for param_key, param_var in vars.items():
|
for param_key, param_var in vars.items():
|
||||||
|
@ -374,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Query variables from storage."""
|
"""Query variables from storage."""
|
||||||
|
|
||||||
|
async def async_get(
|
||||||
|
self,
|
||||||
|
full_key: str,
|
||||||
|
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
|
||||||
|
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Query variables from storage async."""
|
||||||
|
raise NotImplementedError("Current variables provider does not support async.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, variables_item: StorageVariables) -> None:
|
def save(self, variables_item: StorageVariables) -> None:
|
||||||
"""Save variables to storage."""
|
"""Save variables to storage."""
|
||||||
@ -457,6 +466,24 @@ class VariablesPlaceHolder:
|
|||||||
return None
|
return None
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def async_parse(
|
||||||
|
self,
|
||||||
|
variables_provider: VariablesProvider,
|
||||||
|
ignore_not_found_error: bool = False,
|
||||||
|
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||||
|
):
|
||||||
|
"""Parse the variables async."""
|
||||||
|
try:
|
||||||
|
return await variables_provider.async_get(
|
||||||
|
self.full_key,
|
||||||
|
self.default_value,
|
||||||
|
default_identifier_map=default_identifier_map,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
if ignore_not_found_error:
|
||||||
|
return None
|
||||||
|
raise e
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Return the representation of the variables place holder."""
|
"""Return the representation of the variables place holder."""
|
||||||
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
|
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
|
||||||
@ -508,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider):
|
|||||||
variable.value = self.encryption.decrypt(variable.value, variable.salt)
|
variable.value = self.encryption.decrypt(variable.value, variable.salt)
|
||||||
return self._convert_to_value_type(variable)
|
return self._convert_to_value_type(variable)
|
||||||
|
|
||||||
|
async def async_get(
|
||||||
|
self,
|
||||||
|
full_key: str,
|
||||||
|
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
|
||||||
|
default_identifier_map: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Query variables from storage async."""
|
||||||
|
# Try to get variables from storage
|
||||||
|
value = await blocking_func_to_async_no_executor(
|
||||||
|
self.get,
|
||||||
|
full_key,
|
||||||
|
default_value=None,
|
||||||
|
default_identifier_map=default_identifier_map,
|
||||||
|
)
|
||||||
|
if value is not None:
|
||||||
|
return value
|
||||||
|
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
|
||||||
|
# Get all builtin variables
|
||||||
|
variables = await self.async_get_variables(
|
||||||
|
key=key.key,
|
||||||
|
scope=key.scope,
|
||||||
|
scope_key=key.scope_key,
|
||||||
|
sys_code=key.sys_code,
|
||||||
|
user_name=key.user_name,
|
||||||
|
)
|
||||||
|
values = [v for v in variables if v.name == key.name]
|
||||||
|
if not values:
|
||||||
|
if default_value == _EMPTY_DEFAULT_VALUE:
|
||||||
|
raise ValueError(f"Variable {full_key} not found")
|
||||||
|
return default_value
|
||||||
|
if len(values) > 1:
|
||||||
|
raise ValueError(f"Multiple variables found for {full_key}")
|
||||||
|
|
||||||
|
variable = values[0]
|
||||||
|
return self._convert_to_value_type(variable)
|
||||||
|
|
||||||
def save(self, variables_item: StorageVariables) -> None:
|
def save(self, variables_item: StorageVariables) -> None:
|
||||||
"""Save variables to storage."""
|
"""Save variables to storage."""
|
||||||
if variables_item.category == "secret":
|
if variables_item.category == "secret":
|
||||||
@ -577,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider):
|
|||||||
)
|
)
|
||||||
if is_builtin:
|
if is_builtin:
|
||||||
return builtin_variables
|
return builtin_variables
|
||||||
executor_factory: Optional[
|
executor_factory: Optional[DefaultExecutorFactory] = None
|
||||||
DefaultExecutorFactory
|
if self.system_app:
|
||||||
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
|
executor_factory = DefaultExecutorFactory.get_instance(
|
||||||
|
self.system_app, default_component=None
|
||||||
|
)
|
||||||
if executor_factory:
|
if executor_factory:
|
||||||
return await blocking_func_to_async(
|
return await blocking_func_to_async(
|
||||||
executor_factory.create(),
|
executor_factory.create(),
|
||||||
|
@ -4,7 +4,11 @@ from typing import List, Optional, Union
|
|||||||
from sqlalchemy import URL
|
from sqlalchemy import URL
|
||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.core.interface.variables import VariablesProvider
|
from dbgpt.core.interface.variables import (
|
||||||
|
FernetEncryption,
|
||||||
|
StorageVariablesProvider,
|
||||||
|
VariablesProvider,
|
||||||
|
)
|
||||||
from dbgpt.serve.core import BaseServe
|
from dbgpt.serve.core import BaseServe
|
||||||
from dbgpt.storage.metadata import DatabaseManager
|
from dbgpt.storage.metadata import DatabaseManager
|
||||||
|
|
||||||
@ -33,6 +37,7 @@ class Serve(BaseServe):
|
|||||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||||
try_create_tables: Optional[bool] = False,
|
try_create_tables: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if api_prefix is None:
|
if api_prefix is None:
|
||||||
api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"]
|
api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"]
|
||||||
if api_tags is None:
|
if api_tags is None:
|
||||||
@ -41,8 +46,15 @@ class Serve(BaseServe):
|
|||||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||||
)
|
)
|
||||||
self._db_manager: Optional[DatabaseManager] = None
|
self._db_manager: Optional[DatabaseManager] = None
|
||||||
self._variables_provider: Optional[VariablesProvider] = None
|
self._serve_config = ServeConfig.from_app_config(
|
||||||
self._serve_config: Optional[ServeConfig] = None
|
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||||
|
)
|
||||||
|
self._variables_provider: StorageVariablesProvider = StorageVariablesProvider(
|
||||||
|
storage=None,
|
||||||
|
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
||||||
|
system_app=system_app,
|
||||||
|
)
|
||||||
|
system_app.register_instance(self._variables_provider)
|
||||||
|
|
||||||
def init_app(self, system_app: SystemApp):
|
def init_app(self, system_app: SystemApp):
|
||||||
if self._app_has_initiated:
|
if self._app_has_initiated:
|
||||||
@ -65,10 +77,6 @@ class Serve(BaseServe):
|
|||||||
|
|
||||||
def before_start(self):
|
def before_start(self):
|
||||||
"""Called before the start of the application."""
|
"""Called before the start of the application."""
|
||||||
from dbgpt.core.interface.variables import (
|
|
||||||
FernetEncryption,
|
|
||||||
StorageVariablesProvider,
|
|
||||||
)
|
|
||||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||||
|
|
||||||
@ -76,9 +84,6 @@ class Serve(BaseServe):
|
|||||||
from .models.variables_adapter import VariablesAdapter
|
from .models.variables_adapter import VariablesAdapter
|
||||||
|
|
||||||
self._db_manager = self.create_or_get_db_manager()
|
self._db_manager = self.create_or_get_db_manager()
|
||||||
self._serve_config = ServeConfig.from_app_config(
|
|
||||||
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
|
|
||||||
)
|
|
||||||
|
|
||||||
self._db_manager = self.create_or_get_db_manager()
|
self._db_manager = self.create_or_get_db_manager()
|
||||||
storage_adapter = VariablesAdapter()
|
storage_adapter = VariablesAdapter()
|
||||||
@ -89,11 +94,7 @@ class Serve(BaseServe):
|
|||||||
storage_adapter,
|
storage_adapter,
|
||||||
serializer,
|
serializer,
|
||||||
)
|
)
|
||||||
self._variables_provider = StorageVariablesProvider(
|
self._variables_provider.storage = storage
|
||||||
storage=storage,
|
|
||||||
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
|
||||||
system_app=self._system_app,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variables_provider(self):
|
def variables_provider(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user