diff --git a/.env.template b/.env.template index f4213accb..b15d5ce65 100644 --- a/.env.template +++ b/.env.template @@ -272,6 +272,12 @@ DBGPT_LOG_LEVEL=INFO # API_KEYS - The list of API keys that are allowed to access the API. Each of the below are an option, separated by commas. # API_KEYS=dbgpt +#*******************************************************************# +#** ENCRYPT **# +#*******************************************************************# +# ENCRYPT KEY - The key used to encrypt and decrypt the data +# ENCRYPT_KEY=your_secret_key + #*******************************************************************# #** Application Config **# diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 3fdc927c9..6af2a4932 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -303,6 +303,7 @@ class Config(metaclass=Singleton): ) # global dbgpt api key self.API_KEYS = os.getenv("API_KEYS", None) + self.ENCRYPT_KEY = os.getenv("ENCRYPT_KEY", "your_secret_key") # Non-streaming scene retries self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int( diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index e3f414cb0..8d4481f9d 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -1,5 +1,6 @@ """Import all models to make sure they are registered with SQLAlchemy. """ + from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity @@ -11,6 +12,7 @@ from dbgpt.serve.agent.app.recommend_question.recommend_question import ( from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity +from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity from dbgpt.storage.chat_history.chat_history_db import ( @@ -32,4 +34,5 @@ _MODELS = [ ModelInstanceEntity, FlowServeEntity, RecommendQuestionEntity, + FlowVariableEntity, ] diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 65620bff9..43ced6b25 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -7,6 +7,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS) + if cfg.ENCRYPT_KEY: + system_app.config.set("dbgpt.app.global.encrypt_key", cfg.ENCRYPT_KEY) # ################################ Prompt Serve Register Begin ###################################### from dbgpt.serve.prompt.serve import ( diff --git a/dbgpt/component.py b/dbgpt/component.py index bb7a7a9e4..cb88a61ec 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -89,6 +89,7 @@ class ComponentType(str, Enum): CONNECTOR_MANAGER = "dbgpt_connector_manager" AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" + VARIABLES_PROVIDER = "dbgpt_variables_provider" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index ddcfd52bc..512cd6126 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -11,7 +11,18 @@ import uuid from abc import ABC, abstractmethod from collections import deque from concurrent.futures import Executor -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Union, + cast, +) from dbgpt.component import SystemApp @@ -23,6 +34,9 @@ logger = logging.getLogger(__name__) DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]] +if TYPE_CHECKING: + from ...interface.variables import VariablesProvider + def _is_async_context(): try: @@ -128,6 +142,8 @@ class DAGVar: # The executor for current DAG, this is used run some sync tasks in async DAG _executor: Optional[Executor] = None + _variables_provider: Optional["VariablesProvider"] = None + @classmethod def enter_dag(cls, dag) -> None: """Enter a DAG context. @@ -221,6 +237,24 @@ class DAGVar: """ cls._executor = executor + @classmethod + def get_variables_provider(cls) -> Optional["VariablesProvider"]: + """Get the current variables provider. + + Returns: + Optional[VariablesProvider]: The current variables provider + """ + return cls._variables_provider + + @classmethod + def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None: + """Set the current variables provider. + + Args: + variables_provider (VariablesProvider): The variables provider to set + """ + cls._variables_provider = variables_provider + class DAGLifecycle: """The lifecycle of DAG.""" diff --git a/dbgpt/core/awel/flow/__init__.py b/dbgpt/core/awel/flow/__init__.py index 5a173565f..80db5b7e6 100644 --- a/dbgpt/core/awel/flow/__init__.py +++ b/dbgpt/core/awel/flow/__init__.py @@ -7,6 +7,7 @@ from ..util.parameter_util import ( # noqa: F401 BaseDynamicOptions, FunctionDynamicOptions, OptionValue, + VariablesDynamicOptions, ) from .base import ( # noqa: F401 IOField, @@ -35,4 +36,5 @@ __ALL__ = [ "IOField", "BaseDynamicOptions", "FunctionDynamicOptions", + "VariablesDynamicOptions", ] diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 6dd287e56..05912a466 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -6,7 +6,7 @@ import inspect from abc import ABC from datetime import date, datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast +from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast from dbgpt._private.pydantic import ( BaseModel, @@ -15,12 +15,14 @@ from dbgpt._private.pydantic import ( model_to_dict, model_validator, ) +from dbgpt.component import SystemApp from dbgpt.core.awel.util.parameter_util import ( BaseDynamicOptions, OptionValue, RefreshOptionRequest, ) from dbgpt.core.interface.serialization import Serializable +from dbgpt.util.executor_utils import DefaultExecutorFactory, blocking_func_to_async from .exceptions import FlowMetadataException, FlowParameterMetadataException from .ui import UIComponent @@ -490,11 +492,19 @@ class Parameter(TypeMetadata, Serializable): dict_value["ui"] = self.ui.to_dict() return dict_value - def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict: + async def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: """Refresh the options of the parameter. Args: request (RefreshOptionRequest): The request to refresh the options. + trigger (Literal["default", "http"], optional): The trigger type. + Defaults to "default". + system_app (Optional[SystemApp], optional): The system app. Returns: Dict: The response. @@ -503,7 +513,7 @@ class Parameter(TypeMetadata, Serializable): if not self.options: dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): - values = self.options.refresh(request) + values = self.options.refresh(request, trigger, system_app) dict_value["options"] = [value.to_dict() for value in values] else: dict_value["options"] = [value.to_dict() for value in self.options] @@ -793,18 +803,56 @@ class BaseMetadata(BaseResource): ] return dict_value - def refresh(self, request: List[RefreshOptionRequest]) -> Dict: - """Refresh the metadata.""" + async def refresh( + self, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: + """Refresh the metadata. + + Args: + request (List[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + executor = DefaultExecutorFactory.get_instance(system_app).create() + name_to_request = {req.name: req for req in request} parameter_requests = { parameter.name: name_to_request.get(parameter.name) for parameter in self.parameters } - dict_value = self.to_dict() - dict_value["parameters"] = [ - parameter.refresh(parameter_requests.get(parameter.name)) - for parameter in self.parameters - ] + dict_value = model_to_dict(self, exclude={"parameters"}) + parameters = [] + for parameter in self.parameters: + parameter_dict = parameter.to_dict() + parameter_request = parameter_requests.get(parameter.name) + if not parameter.options: + options = None + elif isinstance(parameter.options, BaseDynamicOptions): + options_obj = parameter.options + if options_obj.support_async(system_app, parameter_request): + values = await options_obj.async_refresh( + parameter_request, trigger, system_app + ) + else: + values = await blocking_func_to_async( + executor, + options_obj.refresh, + parameter_request, + trigger, + system_app, + ) + options = [value.to_dict() for value in values] + else: + options = [value.to_dict() for value in self.options] + parameter_dict["options"] = options + parameters.append(parameter_dict) + + dict_value["parameters"] = parameters + return dict_value @@ -1090,14 +1138,23 @@ class FlowRegistry: """Get the metadata list.""" return [item.metadata.to_dict() for item in self._registry.values()] - def refresh( - self, key: str, is_operator: bool, request: List[RefreshOptionRequest] + async def refresh( + self, + key: str, + is_operator: bool, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, ) -> Dict: """Refresh the metadata.""" if is_operator: - return _get_operator_class(key).metadata.refresh(request) # type: ignore + return await _get_operator_class(key).metadata.refresh( # type: ignore + request, trigger, system_app + ) else: - return _get_resource_class(key).metadata.refresh(request) + return await _get_resource_class(key).metadata.refresh( + request, trigger, system_app + ) _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 91008269e..66b413a9f 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -71,7 +71,7 @@ class PanelEditorMixin(BaseModel): class UIComponent(RefreshableMixin, Serializable, BaseModel): """UI component.""" - class UIAttribute(StatusMixin, BaseModel): + class UIAttribute(BaseModel): """Base UI attribute.""" disabled: bool = Field( @@ -106,7 +106,7 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel): class UISelect(UIComponent): """Select component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Select attribute.""" show_search: bool = Field( @@ -138,7 +138,7 @@ class UISelect(UIComponent): class UICascader(UIComponent): """Cascader component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Cascader attribute.""" show_search: bool = Field( @@ -178,7 +178,7 @@ class UICheckbox(UIComponent): class UIDatePicker(UIComponent): """Date picker component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Date picker attribute.""" placement: Optional[ @@ -199,7 +199,7 @@ class UIDatePicker(UIComponent): class UIInput(UIComponent): """Input component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Input attribute.""" prefix: Optional[str] = Field( @@ -216,7 +216,7 @@ class UIInput(UIComponent): None, description="Whether to show count", ) - maxlength: Optional[int] = Field( + max_length: Optional[int] = Field( None, description="The maximum length of the input", ) @@ -294,7 +294,7 @@ class UISlider(UIComponent): class UITimePicker(UIComponent): """Time picker component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Time picker attribute.""" format: Optional[str] = Field( @@ -377,15 +377,20 @@ class UIUpload(UIComponent): ) -class UIVariableInput(UIInput): - """Variable input component.""" +class UIVariablesInput(UIInput): + """Variables input component.""" - ui_type: Literal["variable"] = Field("variable", frozen=True) # type: ignore + ui_type: Literal["variable"] = Field("variables", frozen=True) # type: ignore key: str = Field(..., description="The key of the variable") key_type: Literal["common", "secret"] = Field( "common", description="The type of the key", ) + scope: str = Field("global", description="The scope of the variables") + scope_key: Optional[str] = Field( + None, + description="The key of the scope", + ) refresh: Optional[bool] = Field( True, description="Whether to enable the refresh", @@ -396,7 +401,7 @@ class UIVariableInput(UIInput): self._check_options(parameter_dict.get("options", {})) -class UIPasswordInput(UIVariableInput): +class UIPasswordInput(UIVariablesInput): """Password input component.""" ui_type: Literal["password"] = Field("password", frozen=True) # type: ignore diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index aafa11f90..ac91b6b1a 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -2,10 +2,12 @@ import asyncio import functools +import logging from abc import ABC, ABCMeta, abstractmethod from contextvars import ContextVar from types import FunctionType from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Dict, @@ -29,6 +31,11 @@ from dbgpt.util.tracer import root_tracer from ..dag.base import DAG, DAGContext, DAGNode, DAGVar from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data +if TYPE_CHECKING: + from ...interface.variables import VariablesProvider + +logger = logging.getLogger(__name__) + F = TypeVar("F", bound=FunctionType) CALL_DATA = Union[Dict[str, Any], Any] @@ -92,6 +99,9 @@ class BaseOperatorMeta(ABCMeta): kwargs.get("system_app") or DAGVar.get_current_system_app() ) executor = kwargs.get("executor") or DAGVar.get_executor() + variables_provider = ( + kwargs.get("variables_provider") or DAGVar.get_variables_provider() + ) if not executor: if system_app: executor = system_app.get_component( @@ -102,14 +112,24 @@ class BaseOperatorMeta(ABCMeta): else: executor = DefaultExecutorFactory().create() DAGVar.set_executor(executor) + if not variables_provider: + from ...interface.variables import VariablesProvider + + if system_app: + variables_provider = system_app.get_component( + ComponentType.VARIABLES_PROVIDER, + VariablesProvider, + default_component=None, + ) + else: + from ...interface.variables import StorageVariablesProvider + + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) if not task_id and dag: task_id = dag._new_node_id() runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner - # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}") - # for arg in sig_cache.parameters: - # if arg not in kwargs: - # kwargs[arg] = default_args[arg] if not kwargs.get("dag"): kwargs["dag"] = dag if not kwargs.get("task_id"): @@ -120,6 +140,8 @@ class BaseOperatorMeta(ABCMeta): kwargs["system_app"] = system_app if not kwargs.get("executor"): kwargs["executor"] = executor + if not kwargs.get("variables_provider"): + kwargs["variables_provider"] = variables_provider real_obj = func(self, *args, **kwargs) return real_obj @@ -150,6 +172,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): dag: Optional[DAG] = None, runner: Optional[WorkflowRunner] = None, can_skip_in_branch: bool = True, + variables_provider: Optional["VariablesProvider"] = None, **kwargs, ) -> None: """Create a BaseOperator with an optional workflow runner. @@ -171,6 +194,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): self._runner: WorkflowRunner = runner self._dag_ctx: Optional[DAGContext] = None self._can_skip_in_branch = can_skip_in_branch + self._variables_provider = variables_provider @property def current_dag_context(self) -> DAGContext: @@ -199,6 +223,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if not task_log_id: raise ValueError(f"The task log ID can't be empty, current node {self}") CURRENT_DAG_CONTEXT.set(dag_ctx) + # Resolve variables + await self._resolve_variables(dag_ctx) return await self._do_run(dag_ctx) @abstractmethod @@ -347,6 +373,21 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): """Check if the operator can be skipped in the branch.""" return self._can_skip_in_branch + async def _resolve_variables(self, _: DAGContext): + from ...interface.variables import VariablesPlaceHolder + + if not self._variables_provider: + return + for attr, value in self.__dict__.items(): + if isinstance(value, VariablesPlaceHolder): + resolved_value = await self.blocking_func_to_async( + value.parse, self._variables_provider + ) + logger.debug( + f"Resolve variable {attr} with value {resolved_value} for {self}" + ) + setattr(self, attr, resolved_value) + def initialize_runner(runner: WorkflowRunner): """Initialize the default runner.""" diff --git a/dbgpt/core/awel/tests/test_dag_variables.py b/dbgpt/core/awel/tests/test_dag_variables.py new file mode 100644 index 000000000..88c9b6660 --- /dev/null +++ b/dbgpt/core/awel/tests/test_dag_variables.py @@ -0,0 +1,111 @@ +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio + +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, + VariablesPlaceHolder, +) +from .. import DAG, DAGVar, InputOperator, MapOperator, SimpleInputSource + + +class VariablesOperator(MapOperator[str, str]): + def __init__(self, int_var: int, str_var: str, secret: str, **kwargs): + super().__init__(**kwargs) + self._int_var = int_var + self._str_var = str_var + self._secret = secret + + async def map(self, x: str) -> str: + return ( + f"x: {x}, int_var: {self._int_var}, str_var: {self._str_var}, " + f"secret: {self._secret}" + ) + + +@pytest.fixture +def default_dag(): + with DAG("test_dag") as dag: + input_node = InputOperator(input_source=SimpleInputSource.from_callable()) + map_node = MapOperator(lambda x: x * 2) + input_node >> map_node + return dag + + +@asynccontextmanager +async def _create_variables(**kwargs): + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) + + vars = kwargs.get("vars") + variables = {} + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + variables_provider.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + variables[param_key] = VariablesPlaceHolder(param_key, key, value_type) + else: + raise ValueError("vars is required.") + + with DAG("simple_dag") as dag: + map_node = VariablesOperator(**variables) + yield map_node + + +@pytest_asyncio.fixture +async def variables_node(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as node: + yield node + + +@pytest.mark.asyncio +async def test_default_dag(default_dag: DAG): + leaf_node = default_dag.leaf_nodes[0] + res = await leaf_node.call(2) + assert res == 4 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_node", + [ + ( + { + "vars": { + "int_var": { + "key": "int_key@my_int_var@global", + "value": 0, + "value_type": "int", + }, + "str_var": { + "key": "str_key@my_str_var@global", + "value": "1", + "value_type": "str", + }, + "secret": { + "key": "secret_key@my_secret_var@global", + "value": "2131sdsdf", + "value_type": "str", + "category": "secret", + }, + } + } + ), + ], + indirect=["variables_node"], +) +async def test_input_nodes(variables_node: VariablesOperator): + res = await variables_node.call("test") + assert res == "x: test, int_var: 0, str_var: 1, secret: 2131sdsdf" diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index 2393aed89..a492169c5 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -2,9 +2,10 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from dbgpt._private.pydantic import BaseModel, Field, model_validator +from dbgpt.component import SystemApp from dbgpt.core.interface.serialization import Serializable _DEFAULT_DYNAMIC_REGISTRY = {} @@ -29,6 +30,21 @@ class RefreshOptionRequest(BaseModel): depends: Optional[List[RefreshOptionDependency]] = Field( None, description="The depends of the refresh config" ) + variables_key: Optional[str] = Field( + None, description="The variables key to refresh" + ) + variables_scope: Optional[str] = Field( + None, description="The variables scope to refresh" + ) + variables_scope_key: Optional[str] = Field( + None, description="The variables scope key to refresh" + ) + variables_sys_code: Optional[str] = Field( + None, description="The system code to refresh" + ) + variables_user_name: Optional[str] = Field( + None, description="The user name to refresh" + ) class OptionValue(Serializable, BaseModel): @@ -49,13 +65,57 @@ class OptionValue(Serializable, BaseModel): class BaseDynamicOptions(Serializable, BaseModel, ABC): """The base dynamic options.""" + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async. + + Args: + system_app (Optional[SystemApp]): The system app + request (Optional[RefreshOptionRequest]): The refresh request + + Returns: + bool: Whether the dynamic options support async + """ + return False + def option_values(self) -> List[OptionValue]: """Return the option values of the parameter.""" return self.refresh(None) @abstractmethod - def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: - """Refresh the dynamic options.""" + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options. + + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async. + + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + raise NotImplementedError("The dynamic options does not support async.") class FunctionDynamicOptions(BaseDynamicOptions): @@ -68,7 +128,12 @@ class FunctionDynamicOptions(BaseDynamicOptions): ..., description="The unique id of the function to generate the dynamic options" ) - def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: """Refresh the dynamic options.""" if not request or not request.depends: return self.func() @@ -96,6 +161,109 @@ class FunctionDynamicOptions(BaseDynamicOptions): return {"func_id": self.func_id} +class VariablesDynamicOptions(BaseDynamicOptions): + """The variables dynamic options.""" + + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async.""" + if not system_app or not request or not request.variables_key: + return False + + from ...interface.variables import BuiltinVariablesProvider + + provider: BuiltinVariablesProvider = system_app.get_component( + request.variables_key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False + return provider.support_async() + + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + # Only refresh when trigger is http and request is not None + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = vp.get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + name=var.name, + value=var.value, + ) + ) + return options + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = await vp.async_get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + name=var.name, + value=var.value, + ) + ) + return options + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return {"key": self.key} + + def _generate_unique_id(func: Callable) -> str: if func.__name__ == "": func_id = f"lambda_{inspect.getfile(func)}_{inspect.getsourcelines(func)}" diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index 2a61746ec..4bf152ab8 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -3,13 +3,14 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast -from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.util.annotations import PublicAPI from dbgpt.util.i18n_utils import _ from dbgpt.util.pagination_utils import PaginationResult from dbgpt.util.serialization.json_serialization import JsonSerializer +from ..awel.flow import Parameter, ResourceCategory, register_resource + @PublicAPI(stability="beta") class ResourceIdentifier(Serializable, ABC): diff --git a/dbgpt/core/interface/tests/test_variables.py b/dbgpt/core/interface/tests/test_variables.py new file mode 100644 index 000000000..3b7ab8157 --- /dev/null +++ b/dbgpt/core/interface/tests/test_variables.py @@ -0,0 +1,114 @@ +import base64 +import os + +from cryptography.fernet import Fernet + +from ..variables import ( + FernetEncryption, + InMemoryStorage, + SimpleEncryption, + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, +) + + +def test_fernet_encryption(): + key = Fernet.generate_key() + encryption = FernetEncryption(key) + new_encryption = FernetEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + assert decrypted_data == new_encryption.decrypt(encrypted_data, salt) + + +def test_simple_encryption(): + key = base64.b64encode(os.urandom(32)).decode() + encryption = SimpleEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + + +def test_storage_variables_provider(): + storage = InMemoryStorage() + encryption = SimpleEncryption() + provider = StorageVariablesProvider(storage, encryption) + + full_key = "key@name@global" + value = "secret_value" + value_type = "str" + label = "test_label" + + id = VariablesIdentifier.from_str_identifier(full_key) + provider.save( + StorageVariables.from_identifier( + id, value, value_type, label, category="secret" + ) + ) + + loaded_variable_value = provider.get(full_key) + assert loaded_variable_value == value + + +def test_variables_identifier(): + full_key = "key@name@global@scope_key@sys_code@user_name" + identifier = VariablesIdentifier.from_str_identifier(full_key) + + assert identifier.key == "key" + assert identifier.name == "name" + assert identifier.scope == "global" + assert identifier.scope_key == "scope_key" + assert identifier.sys_code == "sys_code" + assert identifier.user_name == "user_name" + + str_identifier = identifier.str_identifier + assert str_identifier == full_key + + +def test_storage_variables(): + key = "test_key" + name = "test_name" + label = "test_label" + value = "test_value" + value_type = "str" + category = "common" + scope = "global" + + storage_variable = StorageVariables( + key=key, + name=name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=scope, + ) + + assert storage_variable.key == key + assert storage_variable.name == name + assert storage_variable.label == label + assert storage_variable.value == value + assert storage_variable.value_type == value_type + assert storage_variable.category == category + assert storage_variable.scope == scope + + dict_representation = storage_variable.to_dict() + assert dict_representation["key"] == key + assert dict_representation["name"] == name + assert dict_representation["label"] == label + assert dict_representation["value"] == value + assert dict_representation["value_type"] == value_type + assert dict_representation["category"] == category + assert dict_representation["scope"] == scope diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py new file mode 100644 index 000000000..8f99d1e30 --- /dev/null +++ b/dbgpt/core/interface/variables.py @@ -0,0 +1,678 @@ +"""Variables Module.""" + +import base64 +import dataclasses +import hashlib +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.executor_utils import ( + DefaultExecutorFactory, + blocking_func_to_async, + blocking_func_to_async_no_executor, +) + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageInterface, + StorageItem, +) + +_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE" + +BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows" +BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes" +BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables" +BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets" +BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms" +BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings" +BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers" +BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources" +BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents" +BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces" + + +class Encryption(ABC): + """Encryption interface.""" + + name: str = "__abstract__" + + @abstractmethod + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data.""" + + @abstractmethod + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data.""" + + +def _generate_key_from_password( + password: bytes, salt: Optional[Union[str, bytes]] = None +): + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + if salt is None: + salt = os.urandom(16) + elif isinstance(salt, str): + salt = salt.encode() + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(password)) + return key, salt + + +class FernetEncryption(Encryption): + """Fernet encryption. + + A symmetric encryption algorithm that uses the same key for both encryption and + decryption which is powered by the cryptography library. + """ + + name = "fernet" + + def __init__(self, key: Optional[bytes] = None): + """Initialize the fernet encryption.""" + if key is not None and isinstance(key, str): + key = key.encode() + try: + from cryptography.fernet import Fernet + except ImportError: + raise ImportError( + "cryptography is required for encryption, please install by running " + "`pip install cryptography`" + ) + if key is None: + key = Fernet.generate_key() + self.key = key + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt. + + Args: + data (str): The data to encrypt. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The encrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + encrypted_secret = fernet.encrypt(data.encode()).decode() + return encrypted_secret + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt. + + Args: + encrypted_data (str): The encrypted data. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The decrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + return fernet.decrypt(encrypted_data.encode()).decode() + + +class SimpleEncryption(Encryption): + """Simple implementation of encryption. + + A simple encryption algorithm that uses a key to XOR the data. + """ + + name = "simple" + + def __init__(self, key: Optional[str] = None): + """Initialize the simple encryption.""" + if key is None: + key = base64.b64encode(os.urandom(32)).decode() + self.key = key + + def _derive_key(self, salt: str) -> bytes: + return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000) + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt.""" + key = self._derive_key(salt) + encrypted = bytes( + x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1)) + ) + return base64.b64encode(encrypted).decode() + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt.""" + key = self._derive_key(salt) + data = base64.b64decode(encrypted_data) + decrypted = bytes( + x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1)) + ) + return decrypted.decode() + + +@dataclasses.dataclass +class VariablesIdentifier(ResourceIdentifier): + """The variables identifier.""" + + identifier_split: str = dataclasses.field(default="@", init=False) + + key: str + name: str + scope: str = "global" + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + + def __post_init__(self): + """Post init method.""" + if not self.key or not self.name or not self.scope: + raise ValueError("Key, name, and scope are required.") + + if any( + self.identifier_split in key + for key in [ + self.key, + self.name, + self.scope, + self.scope_key, + self.sys_code, + self.user_name, + ] + if key is not None + ): + raise ValueError( + f"identifier_split {self.identifier_split} is not allowed in " + f"key, name, scope, scope_key, sys_code, user_name." + ) + + @property + def str_identifier(self) -> str: + """Return the string identifier of the identifier.""" + return self.identifier_split.join( + key or "" + for key in [ + self.key, + self.name, + self.scope, + self.scope_key, + self.sys_code, + self.user_name, + ] + ) + + def to_dict(self) -> Dict: + """Convert the identifier to a dict. + + Returns: + Dict: The dict of the identifier. + """ + return { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } + + @classmethod + def from_str_identifier( + cls, str_identifier: str, identifier_split: str = "@" + ) -> "VariablesIdentifier": + """Create a VariablesIdentifier from a string identifier. + + Args: + str_identifier (str): The string identifier. + identifier_split (str): The identifier split. + + Returns: + VariablesIdentifier: The VariablesIdentifier. + """ + keys = str_identifier.split(identifier_split) + if not keys: + raise ValueError("Invalid string identifier.") + if len(keys) < 2: + raise ValueError("Invalid string identifier, must have name") + if len(keys) < 3: + raise ValueError("Invalid string identifier, must have scope") + + return cls( + key=keys[0], + name=keys[1], + scope=keys[2], + scope_key=keys[3] if len(keys) > 3 else None, + sys_code=keys[4] if len(keys) > 4 else None, + user_name=keys[5] if len(keys) > 5 else None, + ) + + +@dataclasses.dataclass +class StorageVariables(StorageItem): + """The storage variables.""" + + key: str + name: str + label: str + value: Any + category: Literal["common", "secret"] = "common" + scope: str = "global" + value_type: Optional[str] = None + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + encryption_method: Optional[str] = None + salt: Optional[str] = None + enabled: int = 1 + + _identifier: VariablesIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = VariablesIdentifier( + key=self.key, + name=self.name, + scope=self.scope, + scope_key=self.scope_key, + sys_code=self.sys_code, + user_name=self.user_name, + ) + if not self.value_type: + self.value_type = type(self.value).__name__ + + @property + def identifier(self) -> ResourceIdentifier: + """Return the identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge with another storage variables.""" + if not isinstance(other, StorageVariables): + raise ValueError(f"Cannot merge with {type(other)}") + self.from_object(other) + + def to_dict(self) -> Dict: + """Convert the storage variables to a dict. + + Returns: + Dict: The dict of the storage variables. + """ + return { + **self._identifier.to_dict(), + "label": self.label, + "value": self.value, + "value_type": self.value_type, + "category": self.category, + "encryption_method": self.encryption_method, + "salt": self.salt, + } + + def from_object(self, other: "StorageVariables") -> None: + """Copy the values from another storage variables object.""" + self.label = other.label + self.value = other.value + self.value_type = other.value_type + self.category = other.category + self.scope = other.scope + self.scope_key = other.scope_key + self.sys_code = other.sys_code + self.user_name = other.user_name + self.encryption_method = other.encryption_method + self.salt = other.salt + + @classmethod + def from_identifier( + cls, + identifier: VariablesIdentifier, + value: Any, + value_type: str, + label: str = "", + category: Literal["common", "secret"] = "common", + encryption_method: Optional[str] = None, + salt: Optional[str] = None, + ) -> "StorageVariables": + """Copy the values from an identifier.""" + return cls( + key=identifier.key, + name=identifier.name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=identifier.scope, + scope_key=identifier.scope_key, + sys_code=identifier.sys_code, + user_name=identifier.user_name, + encryption_method=encryption_method, + salt=salt, + ) + + +class VariablesProvider(BaseComponent, ABC): + """The variables provider interface.""" + + name = ComponentType.VARIABLES_PROVIDER.value + + @abstractmethod + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + + @abstractmethod + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + + @abstractmethod + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key.""" + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key async.""" + raise NotImplementedError("Current variables provider does not support async.") + + def support_async(self) -> bool: + """Whether the variables provider support async.""" + return False + + +class VariablesPlaceHolder: + """The variables place holder.""" + + def __init__( + self, + param_name: str, + full_key: str, + value_type: str, + default_value: Any = _EMPTY_DEFAULT_VALUE, + ): + """Initialize the variables place holder.""" + self.param_name = param_name + self.full_key = full_key + self.value_type = value_type + self.default_value = default_value + + def parse(self, variables_provider: VariablesProvider) -> Any: + """Parse the variables.""" + value = variables_provider.get(self.full_key, self.default_value) + if value: + return self._cast_to_type(value) + else: + return value + + def _cast_to_type(self, value: Any) -> Any: + if self.value_type == "str": + return str(value) + elif self.value_type == "int": + return int(value) + elif self.value_type == "float": + return float(value) + elif self.value_type == "bool": + if value.lower() in ["true", "1"]: + return True + elif value.lower() in ["false", "0"]: + return False + else: + return bool(value) + else: + return value + + def __repr__(self): + """Return the representation of the variables place holder.""" + return ( + f"" + ) + + +class StorageVariablesProvider(VariablesProvider): + """The storage variables provider.""" + + def __init__( + self, + storage: Optional[StorageInterface] = None, + encryption: Optional[Encryption] = None, + system_app: Optional[SystemApp] = None, + key: Optional[str] = None, + ): + """Initialize the storage variables provider.""" + if storage is None: + storage = InMemoryStorage() + self.system_app = system_app + self.encryption = encryption or SimpleEncryption(key) + + self.storage = storage + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the storage variables provider.""" + self.system_app = system_app + + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + key = VariablesIdentifier.from_str_identifier(full_key) + variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables) + if variable is None: + if default_value == _EMPTY_DEFAULT_VALUE: + raise ValueError(f"Variable {full_key} not found") + return default_value + variable.value = self.deserialize_value(variable.value) + if ( + variable.value is not None + and variable.category == "secret" + and variable.encryption_method + and variable.salt + ): + variable.value = self.encryption.decrypt(variable.value, variable.salt) + return variable.value + + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + if variables_item.category == "secret": + salt = base64.b64encode(os.urandom(16)).decode() + variables_item.value = self.encryption.encrypt( + str(variables_item.value), salt + ) + variables_item.encryption_method = self.encryption.name + variables_item.salt = salt + # Replace value to a json serializable object + variables_item.value = self.serialize_value(variables_item.value) + + self.storage.save_or_update(variables_item) + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Query variables from storage.""" + # Try to get builtin variables + is_builtin, builtin_variables = self._get_builtins_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + if is_builtin: + return builtin_variables + variables = self.storage.query( + QuerySpec( + conditions={ + "key": key, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + "enabled": 1, + } + ), + StorageVariables, + ) + for variable in variables: + variable.value = self.deserialize_value(variable.value) + return variables + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Query variables from storage async.""" + # Try to get builtin variables + is_builtin, builtin_variables = await self._async_get_builtins_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + if is_builtin: + return builtin_variables + executor_factory: Optional[ + DefaultExecutorFactory + ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None) + if executor_factory: + return await blocking_func_to_async( + executor_factory.create(), + self.get_variables, + key, + scope, + scope_key, + sys_code, + user_name, + ) + else: + return await blocking_func_to_async_no_executor( + self.get_variables, key, scope, scope_key, sys_code, user_name + ) + + def _get_builtins_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> Tuple[bool, List[StorageVariables]]: + """Get builtin variables.""" + if self.system_app is None: + return False, [] + provider: BuiltinVariablesProvider = self.system_app.get_component( + key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False, [] + return True, provider.get_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + + async def _async_get_builtins_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> Tuple[bool, List[StorageVariables]]: + """Get builtin variables.""" + if self.system_app is None: + return False, [] + provider: BuiltinVariablesProvider = self.system_app.get_component( + key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False, [] + if not provider.support_async(): + return False, [] + return True, await provider.async_get_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + + @classmethod + def serialize_value(cls, value: Any) -> str: + """Serialize the value.""" + value_dict = {"value": value} + return json.dumps(value_dict, ensure_ascii=False) + + @classmethod + def deserialize_value(cls, value: str) -> Any: + """Deserialize the value.""" + value_dict = json.loads(value) + return value_dict["value"] + + +class BuiltinVariablesProvider(VariablesProvider, ABC): + """The builtin variables provider. + + You can implement this class to provide builtin variables. Such LLMs, agents, + datasource, knowledge base, etc. + """ + + name = "dbgpt_variables_builtin" + + def __init__(self, system_app: Optional[SystemApp] = None): + """Initialize the builtin variables provider.""" + self.system_app = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the builtin variables provider.""" + self.system_app = system_app + + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + raise NotImplementedError("BuiltinVariablesProvider does not support get.") + + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + raise NotImplementedError("BuiltinVariablesProvider does not support save.") diff --git a/dbgpt/serve/core/__init__.py b/dbgpt/serve/core/__init__.py index 090288128..31edd5d6c 100644 --- a/dbgpt/serve/core/__init__.py +++ b/dbgpt/serve/core/__init__.py @@ -1,7 +1,11 @@ +from typing import Any + from dbgpt.serve.core.config import BaseServeConfig from dbgpt.serve.core.schemas import Result, add_exception_handler from dbgpt.serve.core.serve import BaseServe from dbgpt.serve.core.service import BaseService +from dbgpt.util.executor_utils import BlockingFunction, DefaultExecutorFactory +from dbgpt.util.executor_utils import blocking_func_to_async as _blocking_func_to_async __ALL__ = [ "Result", @@ -10,3 +14,11 @@ __ALL__ = [ "BaseService", "BaseServe", ] + + +async def blocking_func_to_async( + system_app, func: BlockingFunction, *args, **kwargs +) -> Any: + """Run a potentially blocking function within an executor.""" + executor = DefaultExecutorFactory.get_instance(system_app).create() + return await _blocking_func_to_async(executor, func, *args, **kwargs) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index c2c62b95f..001617d6c 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -7,12 +7,19 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata from dbgpt.core.awel.flow.flow_factory import FlowCategory -from dbgpt.serve.core import Result +from dbgpt.serve.core import Result, blocking_func_to_async from dbgpt.util import PaginationResult from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse +from ..service.variables_service import VariablesService +from .schemas import ( + RefreshNodeRequest, + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) router = APIRouter() @@ -23,7 +30,12 @@ global_system_app: Optional[SystemApp] = None def get_service() -> Service: """Get the service instance""" - return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + return Service.get_instance(global_system_app) + + +def get_variable_service() -> VariablesService: + """Get the service instance""" + return VariablesService.get_instance(global_system_app) get_bearer_token = HTTPBearer(auto_error=False) @@ -261,16 +273,80 @@ async def refresh_nodes(refresh_request: RefreshNodeRequest): """ from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - new_metadata = _OPERATOR_REGISTRY.refresh( - key=refresh_request.id, - is_operator=refresh_request.flow_type == "operator", - request=refresh_request.refresh, + # Make sure the variables provider is initialized + _ = get_variable_service().variables_provider + + new_metadata = await _OPERATOR_REGISTRY.refresh( + refresh_request.id, + refresh_request.flow_type == "operator", + refresh_request.refresh, + "http", + global_system_app, ) return Result.succ(new_metadata) +@router.post( + "/variables", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def create_variables( + variables_request: VariablesRequest, +) -> Result[VariablesResponse]: + """Create a new Variables entity + + Args: + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().create, variables_request + ) + return Result.succ(res) + + +@router.put( + "/variables/{v_id}", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def update_variables( + v_id: int, variables_request: VariablesRequest +) -> Result[VariablesResponse]: + """Update a Variables entity + + Args: + v_id (int): The variable id + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().update, v_id, variables_request + ) + return Result.succ(res) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" + from .variables_provider import ( + BuiltinAllSecretVariablesProvider, + BuiltinAllVariablesProvider, + BuiltinEmbeddingsVariablesProvider, + BuiltinFlowVariablesProvider, + BuiltinLLMVariablesProvider, + BuiltinNodeVariablesProvider, + ) + global global_system_app system_app.register(Service) + system_app.register(VariablesService) + system_app.register(BuiltinFlowVariablesProvider) + system_app.register(BuiltinNodeVariablesProvider) + system_app.register(BuiltinAllVariablesProvider) + system_app.register(BuiltinAllSecretVariablesProvider) + system_app.register(BuiltinLLMVariablesProvider) + system_app.register(BuiltinEmbeddingsVariablesProvider) global_system_app = system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 2daa8f581..e63d3e6ce 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Any, List, Literal, Optional from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel.flow.flow_factory import FlowPanel @@ -17,6 +17,69 @@ class ServerResponse(FlowPanel): model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") +class VariablesRequest(BaseModel): + """Variable request model. + + For creating a new variable in the DB-GPT. + """ + + key: str = Field( + ..., + description="The key of the variable to create", + examples=["dbgpt.model.openai.api_key"], + ) + name: str = Field( + ..., + description="The name of the variable to create", + examples=["my_first_openai_key"], + ) + label: str = Field( + ..., + description="The label of the variable to create", + examples=["My First OpenAI Key"], + ) + value: Any = Field( + ..., description="The value of the variable to create", examples=["1234567890"] + ) + value_type: Literal["str", "int", "float", "bool"] = Field( + "str", + description="The type of the value of the variable to create", + examples=["str", "int", "float", "bool"], + ) + category: Literal["common", "secret"] = Field( + ..., + description="The category of the variable to create", + examples=["common"], + ) + scope: str = Field( + ..., + description="The scope of the variable to create", + examples=["global"], + ) + scope_key: Optional[str] = Field( + ..., + description="The scope key of the variable to create", + examples=["dbgpt"], + ) + enabled: Optional[bool] = Field( + True, + description="Whether the variable is enabled", + examples=[True], + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") + + +class VariablesResponse(VariablesRequest): + """Variable response model.""" + + id: int = Field( + ..., + description="The id of the variable", + examples=[1], + ) + + class RefreshNodeRequest(BaseModel): """Flow response model""" diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py new file mode 100644 index 000000000..4728f80e6 --- /dev/null +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -0,0 +1,260 @@ +from typing import List, Literal, Optional + +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + BuiltinVariablesProvider, + StorageVariables, +) + +from ..service.service import Service +from .endpoints import get_service, get_variable_service +from .schemas import ServerResponse + + +class BuiltinFlowVariablesProvider(BuiltinVariablesProvider): + """Builtin flow variables provider. + + Provide all flows by variables "${dbgpt.core.flow.flows}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOWS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + service: Service = get_service() + page_result = service.get_list_by_page( + { + "user_name": user_name, + "sys_code": sys_code, + }, + 1, + 1000, + ) + flows: List[ServerResponse] = page_result.items + variables = [] + for flow in flows: + variables.append( + StorageVariables( + key=key, + name=flow.name, + label=flow.label, + value=flow.uid, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinNodeVariablesProvider(BuiltinVariablesProvider): + """Builtin node variables provider. + + Provide all nodes by variables "${dbgpt.core.flow.nodes}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOW_NODES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + metadata_list = _OPERATOR_REGISTRY.metadata_list() + variables = [] + for metadata in metadata_list: + variables.append( + StorageVariables( + key=key, + name=metadata["name"], + label=metadata["label"], + value=metadata["id"], + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinAllVariablesProvider(BuiltinVariablesProvider): + """Builtin all variables provider. + + Provide all variables by variables "${dbgpt.core.variables}" + """ + + name = BUILTIN_VARIABLES_CORE_VARIABLES + + def _get_variables_from_db( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + category: Literal["common", "secret"] = "common", + ) -> List[StorageVariables]: + storage_variables = get_variable_service().list_all_variables(category) + variables = [] + for var in storage_variables: + variables.append( + StorageVariables( + key=key, + name=var.name, + label=var.label, + value=var.value, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables. + + TODO: Return all builtin variables + """ + return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name) + + +class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider): + """Builtin all secret variables provider. + + Provide all secret variables by variables "${dbgpt.core.secrets}" + """ + + name = BUILTIN_VARIABLES_CORE_SECRETS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return self._get_variables_from_db( + key, scope, scope_key, sys_code, user_name, "secret" + ) + + +class BuiltinLLMVariablesProvider(BuiltinVariablesProvider): + """Builtin LLM variables provider. + + Provide all LLM variables by variables "${dbgpt.core.llmv}" + """ + + name = BUILTIN_VARIABLES_CORE_LLMS + + def support_async(self) -> bool: + """Whether the dynamic options support async.""" + return True + + async def _get_models( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + expect_worker_type: str = "llm", + ) -> List[StorageVariables]: + from dbgpt.model.cluster.controller.controller import BaseModelController + + controller = BaseModelController.get_instance(self.system_app) + models = await controller.get_all_instances(healthy_only=True) + model_dict = {} + for model in models: + worker_name, worker_type = model.model_name.split("@") + if expect_worker_type == worker_type: + model_dict[worker_name] = model + variables = [] + for worker_name, model in model_dict.items(): + variables.append( + StorageVariables( + key=key, + name=worker_name, + label=worker_name, + value=worker_name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models(key, scope, scope_key, sys_code, user_name) + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + raise NotImplementedError( + "Not implemented get variables sync, please use async_get_variables" + ) + + +class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider): + """Builtin embeddings variables provider. + + Provide all embeddings variables by variables "${dbgpt.core.embeddings}" + """ + + name = BUILTIN_VARIABLES_CORE_EMBEDDINGS + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models( + key, scope, scope_key, sys_code, user_name, "text2vec" + ) diff --git a/dbgpt/serve/flow/config.py b/dbgpt/serve/flow/config.py index 97eea7478..0cc35667d 100644 --- a/dbgpt/serve/flow/config.py +++ b/dbgpt/serve/flow/config.py @@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow" SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow" SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow." SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service" # Database table name SERVER_APP_TABLE_NAME = "dbgpt_serve_flow" +SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables" @dataclass @@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig): load_dbgpts_interval: int = field( default=5, metadata={"help": "Interval to load dbgpts from installed packages"} ) + encrypt_key: Optional[str] = field( + default=None, metadata={"help": "The key to encrypt the data"} + ) diff --git a/dbgpt/serve/flow/models/models.py b/dbgpt/serve/flow/models/models.py index ea4c7f3ea..c4166147d 100644 --- a/dbgpt/serve/flow/models/models.py +++ b/dbgpt/serve/flow/models/models.py @@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint from dbgpt._private.pydantic import model_to_dict from dbgpt.core.awel.flow.flow_factory import State +from dbgpt.core.interface.variables import StorageVariablesProvider from dbgpt.storage.metadata import BaseDao, Model from dbgpt.storage.metadata._base_dao import QUERY_SPEC -from ..api.schemas import ServeRequest, ServerResponse -from ..config import SERVER_APP_TABLE_NAME, ServeConfig +from ..api.schemas import ( + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) +from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig class ServeEntity(Model): @@ -74,6 +80,56 @@ class ServeEntity(Model): return editable is None or editable == 0 +class VariablesEntity(Model): + __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME + + id = Column(Integer, primary_key=True, comment="Auto increment id") + key = Column(String(128), index=True, nullable=False, comment="Variable key") + name = Column(String(128), index=True, nullable=True, comment="Variable name") + label = Column(String(128), nullable=True, comment="Variable label") + value = Column(Text, nullable=True, comment="Variable value, JSON format") + value_type = Column( + String(32), + nullable=True, + comment="Variable value type(string, int, float, bool)", + ) + category = Column( + String(32), + default="common", + nullable=True, + comment="Variable category(common or secret)", + ) + encryption_method = Column( + String(32), + nullable=True, + comment="Variable encryption method(fernet, simple, rsa, aes)", + ) + salt = Column(String(128), nullable=True, comment="Variable salt") + scope = Column( + String(32), + default="global", + nullable=True, + comment="Variable scope(global,flow,app,agent,datasource,flow:uid," + "flow:dag_name,agent:agent_name) etc", + ) + scope_key = Column( + String(256), + nullable=True, + comment="Variable scope key, default is empty, for scope is 'flow:uid', " + "the scope_key is uid of flow", + ) + enabled = Column( + Integer, + default=1, + nullable=True, + comment="Variable enabled, 0: disabled, 1: enabled", + ) + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): """The DAO class for Flow""" @@ -222,3 +278,108 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): session.merge(entry) session.commit() return self.get_one(query_request) + + +class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]): + """The DAO class for Variables""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request( + self, request: Union[VariablesRequest, Dict[str, Any]] + ) -> VariablesEntity: + """Convert the request to an entity + + Args: + request (Union[VariablesRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + model_to_dict(request) if isinstance(request, VariablesRequest) else request + ) + value = StorageVariablesProvider.serialize_value(request_dict.get("value")) + enabled = 1 if request_dict.get("enabled", True) else 0 + new_dict = { + "key": request_dict.get("key"), + "name": request_dict.get("name"), + "label": request_dict.get("label"), + "value": value, + "value_type": request_dict.get("value_type"), + "category": request_dict.get("category"), + "encryption_method": request_dict.get("encryption_method"), + "salt": request_dict.get("salt"), + "scope": request_dict.get("scope"), + "scope_key": request_dict.get("scope_key"), + "enabled": enabled, + "user_name": request_dict.get("user_name"), + "sys_code": request_dict.get("sys_code"), + } + entity = VariablesEntity(**new_dict) + return entity + + def to_request(self, entity: VariablesEntity) -> VariablesRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + enabled = entity.enabled == 1 + return VariablesRequest( + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + ) + + def to_response(self, entity: VariablesEntity) -> VariablesResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") + gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + enabled = entity.enabled == 1 + return VariablesResponse( + id=entity.id, + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + gmt_created=gmt_created_str, + gmt_modified=gmt_modified_str, + ) diff --git a/dbgpt/serve/flow/models/variables_adapter.py b/dbgpt/serve/flow/models/variables_adapter.py new file mode 100644 index 000000000..d8a1ef1e0 --- /dev/null +++ b/dbgpt/serve/flow/models/variables_adapter.py @@ -0,0 +1,69 @@ +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.storage import StorageItemAdapter +from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier + +from .models import VariablesEntity + + +class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]): + """Variables adapter. + + Convert between storage format and database model. + """ + + def to_storage_format(self, item: StorageVariables) -> VariablesEntity: + """Convert to storage format.""" + return VariablesEntity( + key=item.key, + name=item.name, + label=item.label, + value=item.value, + value_type=item.value_type, + category=item.category, + encryption_method=item.encryption_method, + salt=item.salt, + scope=item.scope, + scope_key=item.scope_key, + sys_code=item.sys_code, + user_name=item.user_name, + ) + + def from_storage_format(self, model: VariablesEntity) -> StorageVariables: + """Convert from storage format.""" + return StorageVariables( + key=model.key, + name=model.name, + label=model.label, + value=model.value, + value_type=model.value_type, + category=model.category, + encryption_method=model.encryption_method, + salt=model.salt, + scope=model.scope, + scope_key=model.scope_key, + sys_code=model.sys_code, + user_name=model.user_name, + ) + + def get_query_for_identifier( + self, + storage_format: Type[VariablesEntity], + resource_id: VariablesIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + query_obj = session.query(VariablesEntity) + for key, value in resource_id.to_dict().items(): + if value is None: + continue + query_obj = query_obj.filter(getattr(VariablesEntity, key) == value) + + # enabled must be True + query_obj = query_obj.filter(VariablesEntity.enabled == 1) + return query_obj diff --git a/dbgpt/serve/flow/serve.py b/dbgpt/serve/flow/serve.py index 126841e57..a27e3d28f 100644 --- a/dbgpt/serve/flow/serve.py +++ b/dbgpt/serve/flow/serve.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union from sqlalchemy import URL from dbgpt.component import SystemApp +from dbgpt.core.interface.variables import VariablesProvider from dbgpt.serve.core import BaseServe from dbgpt.storage.metadata import DatabaseManager @@ -40,6 +41,8 @@ class Serve(BaseServe): system_app, api_prefix, api_tags, db_url_or_db, try_create_tables ) self._db_manager: Optional[DatabaseManager] = None + self._variables_provider: Optional[VariablesProvider] = None + self._serve_config: Optional[ServeConfig] = None def init_app(self, system_app: SystemApp): if self._app_has_initiated: @@ -62,5 +65,37 @@ class Serve(BaseServe): def before_start(self): """Called before the start of the application.""" - # TODO: Your code here + from dbgpt.core.interface.variables import ( + FernetEncryption, + StorageVariablesProvider, + ) + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + + from .models.models import ServeEntity, VariablesEntity + from .models.variables_adapter import VariablesAdapter + self._db_manager = self.create_or_get_db_manager() + self._serve_config = ServeConfig.from_app_config( + self._system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + + self._db_manager = self.create_or_get_db_manager() + storage_adapter = VariablesAdapter() + serializer = JsonSerializer() + storage = SQLAlchemyStorage( + self._db_manager, + VariablesEntity, + storage_adapter, + serializer, + ) + self._variables_provider = StorageVariablesProvider( + storage=storage, + encryption=FernetEncryption(self._serve_config.encrypt_key), + system_app=self._system_app, + ) + + @property + def variables_provider(self): + """Get the variables provider of the serve app with db storage""" + return self._variables_provider diff --git a/dbgpt/serve/flow/service/variables_service.py b/dbgpt/serve/flow/service/variables_service.py new file mode 100644 index 000000000..4b79d27db --- /dev/null +++ b/dbgpt/serve/flow/service/variables_service.py @@ -0,0 +1,148 @@ +from typing import List, Optional + +from dbgpt import SystemApp +from dbgpt.core.interface.variables import StorageVariables, VariablesProvider +from dbgpt.serve.core import BaseService + +from ..api.schemas import VariablesRequest, VariablesResponse +from ..config import ( + SERVE_CONFIG_KEY_PREFIX, + SERVE_VARIABLES_SERVICE_COMPONENT_NAME, + ServeConfig, +) +from ..models.models import VariablesDao, VariablesEntity + + +class VariablesService( + BaseService[VariablesEntity, VariablesRequest, VariablesResponse] +): + """Variables service""" + + name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: VariablesDao = dao + + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or VariablesDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> VariablesDao: + """Returns the internal DAO.""" + return self._dao + + @property + def variables_provider(self) -> VariablesProvider: + """Returns the internal VariablesProvider. + + Returns: + VariablesProvider: The internal VariablesProvider + """ + variables_provider = VariablesProvider.get_instance( + self._system_app, default_component=None + ) + if variables_provider: + return variables_provider + else: + from ..serve import Serve + + variables_provider = Serve.get_instance(self._system_app).variables_provider + self._system_app.register_instance(variables_provider) + return variables_provider + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def create(self, request: VariablesRequest) -> VariablesResponse: + """Create a new entity + + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def update(self, _: int, request: VariablesRequest) -> VariablesResponse: + """Update variables. + + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + ) + exist_value = self.variables_provider.get( + variables.identifier.str_identifier, None + ) + if exist_value is None: + raise ValueError( + f"Variable {variables.identifier.str_identifier} not found" + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def list_all_variables(self, category: str = "common") -> List[VariablesResponse]: + """List all variables.""" + return self.dao.get_list({"enabled": True, "category": category}) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index fc8d9a5c4..7a38f8d4b 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -1,5 +1,6 @@ """Some UI components for the AWEL flow.""" +import json import logging from typing import List, Optional @@ -10,9 +11,18 @@ from dbgpt.core.awel.flow import ( OperatorCategory, OptionValue, Parameter, + VariablesDynamicOptions, ViewMetadata, ui, ) +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, +) logger = logging.getLogger(__name__) @@ -717,3 +727,157 @@ class ExampleFlowRefreshOperator(MapOperator[str, str]): user_name, self.recent_time, ) + + +class ExampleFlowVariablesOperator(MapOperator[str, str]): + """An example flow operator that includes a variables option.""" + + metadata = ViewMetadata( + label="Example Variables Operator", + name="example_variables_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a variables option.", + parameters=[ + Parameter.build_from( + "OpenAI API Key", + "openai_api_key", + type=str, + placeholder="Please select the OpenAI API key", + description="The OpenAI API key to use.", + options=VariablesDynamicOptions(), + ui=ui.UIPasswordInput( + key="dbgpt.model.openai.api_key", + ), + ), + Parameter.build_from( + "Model", + "model", + type=str, + placeholder="Please select the model", + description="The model to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.model.openai.model", + ), + ), + Parameter.build_from( + "Builtin Flows", + "builtin_flow", + type=str, + placeholder="Please select the builtin flows", + description="The builtin flows to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_FLOWS, + ), + ), + Parameter.build_from( + "Builtin Flow Nodes", + "builtin_flow_node", + type=str, + placeholder="Please select the builtin flow nodes", + description="The builtin flow nodes to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_FLOW_NODES, + ), + ), + Parameter.build_from( + "Builtin Variables", + "builtin_variable", + type=str, + placeholder="Please select the builtin variables", + description="The builtin variables to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_VARIABLES, + ), + ), + Parameter.build_from( + "Builtin Secrets", + "builtin_secret", + type=str, + placeholder="Please select the builtin secrets", + description="The builtin secrets to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_SECRETS, + ), + ), + Parameter.build_from( + "Builtin LLMs", + "builtin_llm", + type=str, + placeholder="Please select the builtin LLMs", + description="The builtin LLMs to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_LLMS, + ), + ), + Parameter.build_from( + "Builtin Embeddings", + "builtin_embedding", + type=str, + placeholder="Please select the builtin embeddings", + description="The builtin embeddings to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_EMBEDDINGS, + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Model info", + "model", + str, + description="The model info.", + ), + ], + ) + + def __init__( + self, + openai_api_key: str, + model: str, + builtin_flow: str, + builtin_flow_node: str, + builtin_variable: str, + builtin_secret: str, + builtin_llm: str, + builtin_embedding: str, + **kwargs, + ): + super().__init__(**kwargs) + self.openai_api_key = openai_api_key + self.model = model + self.builtin_flow = builtin_flow + self.builtin_flow_node = builtin_flow_node + self.builtin_variable = builtin_variable + self.builtin_secret = builtin_secret + self.builtin_llm = builtin_llm + self.builtin_embedding = builtin_embedding + + async def map(self, user_name: str) -> str: + """Map the user name to the model.""" + dict_dict = { + "openai_api_key": self.openai_api_key, + "model": self.model, + "builtin_flow": self.builtin_flow, + "builtin_flow_node": self.builtin_flow_node, + "builtin_variable": self.builtin_variable, + "builtin_secret": self.builtin_secret, + "builtin_llm": self.builtin_llm, + "builtin_embedding": self.builtin_embedding, + } + json_data = json.dumps(dict_dict, ensure_ascii=False) + return "Your name is %s, and your model info is %s." % (user_name, json_data) diff --git a/setup.py b/setup.py index cbe5592ce..a968892df 100644 --- a/setup.py +++ b/setup.py @@ -498,6 +498,8 @@ def core_requires(): "GitPython", # For AWEL dag visualization, graphviz is a small package, also we can move it to default. "graphviz", + # For security + "cryptography", ]