feat(core): Support dag scope variables

This commit is contained in:
Fangyin Cheng
2024-08-22 11:28:40 +08:00
parent 97b57fb071
commit 22c3d73fe7
26 changed files with 2265 additions and 46 deletions

View File

@@ -303,6 +303,7 @@ class Config(metaclass=Singleton):
)
# global dbgpt api key
self.API_KEYS = os.getenv("API_KEYS", None)
self.ENCRYPT_KEY = os.getenv("ENCRYPT_KEY", "your_secret_key")
# Non-streaming scene retries
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(

View File

@@ -1,5 +1,6 @@
"""Import all models to make sure they are registered with SQLAlchemy.
"""
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
@@ -11,6 +12,7 @@ from dbgpt.serve.agent.app.recommend_question.recommend_question import (
from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginEntity
from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubEntity
from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity
from dbgpt.storage.chat_history.chat_history_db import (
@@ -32,4 +34,5 @@ _MODELS = [
ModelInstanceEntity,
FlowServeEntity,
RecommendQuestionEntity,
FlowVariableEntity,
]

View File

@@ -7,6 +7,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
if cfg.API_KEYS:
system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS)
if cfg.ENCRYPT_KEY:
system_app.config.set("dbgpt.app.global.encrypt_key", cfg.ENCRYPT_KEY)
# ################################ Prompt Serve Register Begin ######################################
from dbgpt.serve.prompt.serve import (

View File

@@ -89,6 +89,7 @@ class ComponentType(str, Enum):
CONNECTOR_MANAGER = "dbgpt_connector_manager"
AGENT_MANAGER = "dbgpt_agent_manager"
RESOURCE_MANAGER = "dbgpt_resource_manager"
VARIABLES_PROVIDER = "dbgpt_variables_provider"
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"

View File

@@ -11,7 +11,18 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Union,
cast,
)
from dbgpt.component import SystemApp
@@ -23,6 +34,9 @@ logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
if TYPE_CHECKING:
from ...interface.variables import VariablesProvider
def _is_async_context():
try:
@@ -128,6 +142,8 @@ class DAGVar:
# The executor for current DAG, this is used run some sync tasks in async DAG
_executor: Optional[Executor] = None
_variables_provider: Optional["VariablesProvider"] = None
@classmethod
def enter_dag(cls, dag) -> None:
"""Enter a DAG context.
@@ -221,6 +237,24 @@ class DAGVar:
"""
cls._executor = executor
@classmethod
def get_variables_provider(cls) -> Optional["VariablesProvider"]:
    """Return the variables provider currently stored on the class.

    Returns:
        Optional[VariablesProvider]: The stored provider, or ``None`` when no
            provider has been set.
    """
    provider = cls._variables_provider
    return provider
@classmethod
def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None:
    """Store the variables provider on the class.

    Args:
        variables_provider (VariablesProvider): The provider to store.
    """
    setattr(cls, "_variables_provider", variables_provider)
class DAGLifecycle:
"""The lifecycle of DAG."""

View File

@@ -7,6 +7,7 @@ from ..util.parameter_util import ( # noqa: F401
BaseDynamicOptions,
FunctionDynamicOptions,
OptionValue,
VariablesDynamicOptions,
)
from .base import ( # noqa: F401
IOField,
@@ -35,4 +36,5 @@ __ALL__ = [
"IOField",
"BaseDynamicOptions",
"FunctionDynamicOptions",
"VariablesDynamicOptions",
]

View File

@@ -6,7 +6,7 @@ import inspect
from abc import ABC
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast
from dbgpt._private.pydantic import (
BaseModel,
@@ -15,12 +15,14 @@ from dbgpt._private.pydantic import (
model_to_dict,
model_validator,
)
from dbgpt.component import SystemApp
from dbgpt.core.awel.util.parameter_util import (
BaseDynamicOptions,
OptionValue,
RefreshOptionRequest,
)
from dbgpt.core.interface.serialization import Serializable
from dbgpt.util.executor_utils import DefaultExecutorFactory, blocking_func_to_async
from .exceptions import FlowMetadataException, FlowParameterMetadataException
from .ui import UIComponent
@@ -490,11 +492,19 @@ class Parameter(TypeMetadata, Serializable):
dict_value["ui"] = self.ui.to_dict()
return dict_value
def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict:
async def refresh(
self,
request: Optional[RefreshOptionRequest] = None,
trigger: Literal["default", "http"] = "default",
system_app: Optional[SystemApp] = None,
) -> Dict:
"""Refresh the options of the parameter.
Args:
request (RefreshOptionRequest): The request to refresh the options.
trigger (Literal["default", "http"], optional): The trigger type.
Defaults to "default".
system_app (Optional[SystemApp], optional): The system app.
Returns:
Dict: The response.
@@ -503,7 +513,7 @@ class Parameter(TypeMetadata, Serializable):
if not self.options:
dict_value["options"] = None
elif isinstance(self.options, BaseDynamicOptions):
values = self.options.refresh(request)
values = self.options.refresh(request, trigger, system_app)
dict_value["options"] = [value.to_dict() for value in values]
else:
dict_value["options"] = [value.to_dict() for value in self.options]
@@ -793,18 +803,56 @@ class BaseMetadata(BaseResource):
]
return dict_value
async def refresh(
    self,
    request: List[RefreshOptionRequest],
    trigger: Literal["default", "http"] = "default",
    system_app: Optional[SystemApp] = None,
) -> Dict:
    """Refresh the metadata.

    The options of every parameter are recomputed: dynamic options are
    refreshed (natively async when they support it, otherwise in the executor)
    and static options are serialized as they are.

    Args:
        request (List[RefreshOptionRequest]): The refresh request
        trigger (Literal["default", "http"]): The trigger type, how to trigger
            the refresh
        system_app (Optional[SystemApp]): The system app

    Returns:
        Dict: The metadata dict with a refreshed "parameters" list.
    """
    executor = DefaultExecutorFactory.get_instance(system_app).create()
    # Requests are matched to parameters by name; parameters without a
    # matching request are refreshed with ``None``.
    name_to_request = {req.name: req for req in request}

    dict_value = model_to_dict(self, exclude={"parameters"})
    parameters = []
    for parameter in self.parameters:
        parameter_dict = parameter.to_dict()
        parameter_request = name_to_request.get(parameter.name)
        if not parameter.options:
            options = None
        elif isinstance(parameter.options, BaseDynamicOptions):
            options_obj = parameter.options
            if options_obj.support_async(system_app, parameter_request):
                values = await options_obj.async_refresh(
                    parameter_request, trigger, system_app
                )
            else:
                # Sync-only options run in the executor so the event loop is
                # not blocked.
                values = await blocking_func_to_async(
                    executor,
                    options_obj.refresh,
                    parameter_request,
                    trigger,
                    system_app,
                )
            options = [value.to_dict() for value in values]
        else:
            # Static options belong to the parameter, not to the metadata.
            options = [value.to_dict() for value in parameter.options]
        parameter_dict["options"] = options
        parameters.append(parameter_dict)

    dict_value["parameters"] = parameters
    return dict_value
@@ -1090,14 +1138,23 @@ class FlowRegistry:
"""Get the metadata list."""
return [item.metadata.to_dict() for item in self._registry.values()]
async def refresh(
    self,
    key: str,
    is_operator: bool,
    request: List[RefreshOptionRequest],
    trigger: Literal["default", "http"] = "default",
    system_app: Optional[SystemApp] = None,
) -> Dict:
    """Refresh the metadata registered under ``key``.

    Args:
        key (str): The key passed to the operator/resource class lookup.
        is_operator (bool): Look up an operator class when true, otherwise a
            resource class.
        request (List[RefreshOptionRequest]): The refresh request.
        trigger (Literal["default", "http"]): The trigger type.
        system_app (Optional[SystemApp]): The system app.

    Returns:
        Dict: The refreshed metadata.
    """
    lookup = _get_operator_class if is_operator else _get_resource_class
    metadata = lookup(key).metadata  # type: ignore
    return await metadata.refresh(request, trigger, system_app)
_OPERATOR_REGISTRY: FlowRegistry = FlowRegistry()

View File

@@ -71,7 +71,7 @@ class PanelEditorMixin(BaseModel):
class UIComponent(RefreshableMixin, Serializable, BaseModel):
"""UI component."""
class UIAttribute(StatusMixin, BaseModel):
class UIAttribute(BaseModel):
"""Base UI attribute."""
disabled: bool = Field(
@@ -106,7 +106,7 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
class UISelect(UIComponent):
"""Select component."""
class UIAttribute(UIComponent.UIAttribute):
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
"""Select attribute."""
show_search: bool = Field(
@@ -138,7 +138,7 @@ class UISelect(UIComponent):
class UICascader(UIComponent):
"""Cascader component."""
class UIAttribute(UIComponent.UIAttribute):
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
"""Cascader attribute."""
show_search: bool = Field(
@@ -178,7 +178,7 @@ class UICheckbox(UIComponent):
class UIDatePicker(UIComponent):
"""Date picker component."""
class UIAttribute(UIComponent.UIAttribute):
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
"""Date picker attribute."""
placement: Optional[
@@ -199,7 +199,7 @@ class UIDatePicker(UIComponent):
class UIInput(UIComponent):
"""Input component."""
class UIAttribute(UIComponent.UIAttribute):
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
"""Input attribute."""
prefix: Optional[str] = Field(
@@ -216,7 +216,7 @@ class UIInput(UIComponent):
None,
description="Whether to show count",
)
maxlength: Optional[int] = Field(
max_length: Optional[int] = Field(
None,
description="The maximum length of the input",
)
@@ -294,7 +294,7 @@ class UISlider(UIComponent):
class UITimePicker(UIComponent):
"""Time picker component."""
class UIAttribute(UIComponent.UIAttribute):
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
"""Time picker attribute."""
format: Optional[str] = Field(
@@ -377,15 +377,20 @@ class UIUpload(UIComponent):
)
class UIVariableInput(UIInput):
"""Variable input component."""
class UIVariablesInput(UIInput):
"""Variables input component."""
ui_type: Literal["variable"] = Field("variable", frozen=True) # type: ignore
ui_type: Literal["variable"] = Field("variables", frozen=True) # type: ignore
key: str = Field(..., description="The key of the variable")
key_type: Literal["common", "secret"] = Field(
"common",
description="The type of the key",
)
scope: str = Field("global", description="The scope of the variables")
scope_key: Optional[str] = Field(
None,
description="The key of the scope",
)
refresh: Optional[bool] = Field(
True,
description="Whether to enable the refresh",
@@ -396,7 +401,7 @@ class UIVariableInput(UIInput):
self._check_options(parameter_dict.get("options", {}))
class UIPasswordInput(UIVariableInput):
class UIPasswordInput(UIVariablesInput):
"""Password input component."""
ui_type: Literal["password"] = Field("password", frozen=True) # type: ignore

View File

@@ -2,10 +2,12 @@
import asyncio
import functools
import logging
from abc import ABC, ABCMeta, abstractmethod
from contextvars import ContextVar
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
@@ -29,6 +31,11 @@ from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data
if TYPE_CHECKING:
from ...interface.variables import VariablesProvider
logger = logging.getLogger(__name__)
F = TypeVar("F", bound=FunctionType)
CALL_DATA = Union[Dict[str, Any], Any]
@@ -92,6 +99,9 @@ class BaseOperatorMeta(ABCMeta):
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
executor = kwargs.get("executor") or DAGVar.get_executor()
variables_provider = (
kwargs.get("variables_provider") or DAGVar.get_variables_provider()
)
if not executor:
if system_app:
executor = system_app.get_component(
@@ -102,14 +112,24 @@ class BaseOperatorMeta(ABCMeta):
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
if not variables_provider:
from ...interface.variables import VariablesProvider
if system_app:
variables_provider = system_app.get_component(
ComponentType.VARIABLES_PROVIDER,
VariablesProvider,
default_component=None,
)
else:
from ...interface.variables import StorageVariablesProvider
variables_provider = StorageVariablesProvider()
DAGVar.set_variables_provider(variables_provider)
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
# for arg in sig_cache.parameters:
# if arg not in kwargs:
# kwargs[arg] = default_args[arg]
if not kwargs.get("dag"):
kwargs["dag"] = dag
if not kwargs.get("task_id"):
@@ -120,6 +140,8 @@ class BaseOperatorMeta(ABCMeta):
kwargs["system_app"] = system_app
if not kwargs.get("executor"):
kwargs["executor"] = executor
if not kwargs.get("variables_provider"):
kwargs["variables_provider"] = variables_provider
real_obj = func(self, *args, **kwargs)
return real_obj
@@ -150,6 +172,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
dag: Optional[DAG] = None,
runner: Optional[WorkflowRunner] = None,
can_skip_in_branch: bool = True,
variables_provider: Optional["VariablesProvider"] = None,
**kwargs,
) -> None:
"""Create a BaseOperator with an optional workflow runner.
@@ -171,6 +194,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch
self._variables_provider = variables_provider
@property
def current_dag_context(self) -> DAGContext:
@@ -199,6 +223,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not task_log_id:
raise ValueError(f"The task log ID can't be empty, current node {self}")
CURRENT_DAG_CONTEXT.set(dag_ctx)
# Resolve variables
await self._resolve_variables(dag_ctx)
return await self._do_run(dag_ctx)
@abstractmethod
@@ -347,6 +373,21 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Check if the operator can be skipped in the branch."""
return self._can_skip_in_branch
async def _resolve_variables(self, _: DAGContext):
    """Replace variable placeholders on this operator with resolved values.

    Every instance attribute holding a ``VariablesPlaceHolder`` is parsed with
    the operator's variables provider and the attribute is overwritten with
    the parsed result. Nothing happens when the operator has no variables
    provider.

    Args:
        _ (DAGContext): The DAG context of the current run (unused).
    """
    from ...interface.variables import VariablesPlaceHolder

    if not self._variables_provider:
        return
    # Take a snapshot first: the loop awaits and reassigns attributes, so the
    # live ``__dict__`` must not be iterated while it can change.
    placeholders = [
        (attr, value)
        for attr, value in self.__dict__.items()
        if isinstance(value, VariablesPlaceHolder)
    ]
    for attr, placeholder in placeholders:
        resolved_value = await self.blocking_func_to_async(
            placeholder.parse, self._variables_provider
        )
        # The resolved value may be a secret, so only the attribute name is
        # logged.
        logger.debug("Resolve variable %s for %s", attr, self)
        setattr(self, attr, resolved_value)
def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""

View File

@@ -0,0 +1,111 @@
from contextlib import asynccontextmanager
import pytest
import pytest_asyncio
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
VariablesPlaceHolder,
)
from .. import DAG, DAGVar, InputOperator, MapOperator, SimpleInputSource
class VariablesOperator(MapOperator[str, str]):
    """Map operator that reports the variables it was constructed with."""

    def __init__(self, int_var: int, str_var: str, secret: str, **kwargs):
        """Store the three variables; remaining kwargs go to ``MapOperator``."""
        super().__init__(**kwargs)
        self._int_var, self._str_var, self._secret = int_var, str_var, secret

    async def map(self, x: str) -> str:
        """Return a string describing the input and the stored variables."""
        parts = [
            f"x: {x}",
            f"int_var: {self._int_var}",
            f"str_var: {self._str_var}",
            f"secret: {self._secret}",
        ]
        return ", ".join(parts)
@pytest.fixture
def default_dag():
    """Build a two-node DAG: an input operator feeding a doubling map."""
    with DAG("test_dag") as dag:
        source = InputOperator(input_source=SimpleInputSource.from_callable())
        doubler = MapOperator(lambda value: value * 2)
        source >> doubler
    return dag
@asynccontextmanager
async def _create_variables(**kwargs):
    """Save the given variables and yield an operator built from placeholders.

    ``kwargs["vars"]`` maps a constructor parameter name of
    ``VariablesOperator`` to a dict with "key", "value", "value_type" and an
    optional "category" (defaults to "common"). Each entry is saved to a fresh
    ``StorageVariablesProvider`` (also set on ``DAGVar``) and replaced by a
    ``VariablesPlaceHolder``.

    Raises:
        ValueError: If "vars" is missing, empty, or not a dict.
    """
    provider = StorageVariablesProvider()
    DAGVar.set_variables_provider(provider)

    var_specs = kwargs.get("vars")
    if not (var_specs and isinstance(var_specs, dict)):
        raise ValueError("vars is required.")

    placeholders = {}
    for param_name, spec in var_specs.items():
        full_key = spec.get("key")
        value = spec.get("value")
        value_type = spec.get("value_type")
        category = spec.get("category", "common")
        identifier = VariablesIdentifier.from_str_identifier(full_key)
        provider.save(
            StorageVariables.from_identifier(
                identifier, value, value_type, label="", category=category
            )
        )
        placeholders[param_name] = VariablesPlaceHolder(
            param_name, full_key, value_type
        )

    with DAG("simple_dag"):
        map_node = VariablesOperator(**placeholders)

    yield map_node
@pytest_asyncio.fixture
async def variables_node(request):
    """Yield the operator built by ``_create_variables`` from ``request.param``.

    An empty dict is used when the fixture is not parametrized.
    """
    fixture_kwargs = getattr(request, "param", {})
    async with _create_variables(**fixture_kwargs) as operator:
        yield operator
@pytest.mark.asyncio
async def test_default_dag(default_dag: DAG):
    """Calling the first leaf node of the default DAG with 2 yields 4."""
    leaf = default_dag.leaf_nodes[0]
    result = await leaf.call(2)
    assert result == 4
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "variables_node",
    [
        {
            "vars": {
                "int_var": {
                    "key": "int_key@my_int_var@global",
                    "value": 0,
                    "value_type": "int",
                },
                "str_var": {
                    "key": "str_key@my_str_var@global",
                    "value": "1",
                    "value_type": "str",
                },
                "secret": {
                    "key": "secret_key@my_secret_var@global",
                    "value": "2131sdsdf",
                    "value_type": "str",
                    "category": "secret",
                },
            }
        },
    ],
    indirect=["variables_node"],
)
async def test_input_nodes(variables_node: VariablesOperator):
    """Placeholders are resolved to the saved values before ``map`` runs."""
    output = await variables_node.call("test")
    assert output == "x: test, int_var: 0, str_var: 1, secret: 2131sdsdf"

View File

@@ -2,9 +2,10 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
from dbgpt._private.pydantic import BaseModel, Field, model_validator
from dbgpt.component import SystemApp
from dbgpt.core.interface.serialization import Serializable
_DEFAULT_DYNAMIC_REGISTRY = {}
@@ -29,6 +30,21 @@ class RefreshOptionRequest(BaseModel):
depends: Optional[List[RefreshOptionDependency]] = Field(
None, description="The depends of the refresh config"
)
variables_key: Optional[str] = Field(
None, description="The variables key to refresh"
)
variables_scope: Optional[str] = Field(
None, description="The variables scope to refresh"
)
variables_scope_key: Optional[str] = Field(
None, description="The variables scope key to refresh"
)
variables_sys_code: Optional[str] = Field(
None, description="The system code to refresh"
)
variables_user_name: Optional[str] = Field(
None, description="The user name to refresh"
)
class OptionValue(Serializable, BaseModel):
@@ -49,13 +65,57 @@ class OptionValue(Serializable, BaseModel):
class BaseDynamicOptions(Serializable, BaseModel, ABC):
    """The base dynamic options."""

    def support_async(
        self,
        system_app: Optional[SystemApp] = None,
        request: Optional[RefreshOptionRequest] = None,
    ) -> bool:
        """Tell whether ``async_refresh`` can be used for this request.

        The base implementation always answers ``False``.

        Args:
            system_app (Optional[SystemApp]): The system app
            request (Optional[RefreshOptionRequest]): The refresh request

        Returns:
            bool: Whether the dynamic options support async
        """
        return False

    def option_values(self) -> List[OptionValue]:
        """Return the option values produced by a refresh without a request."""
        return self.refresh(None)

    @abstractmethod
    def refresh(
        self,
        request: Optional[RefreshOptionRequest] = None,
        trigger: Literal["default", "http"] = "default",
        system_app: Optional[SystemApp] = None,
    ) -> List[OptionValue]:
        """Refresh the dynamic options.

        Args:
            request (Optional[RefreshOptionRequest]): The refresh request
            trigger (Literal["default", "http"]): The trigger type, how to trigger
                the refresh
            system_app (Optional[SystemApp]): The system app

        Returns:
            List[OptionValue]: The refreshed option values
        """

    async def async_refresh(
        self,
        request: Optional[RefreshOptionRequest] = None,
        trigger: Literal["default", "http"] = "default",
        system_app: Optional[SystemApp] = None,
    ) -> List[OptionValue]:
        """Refresh the dynamic options asynchronously.

        Args:
            request (Optional[RefreshOptionRequest]): The refresh request
            trigger (Literal["default", "http"]): The trigger type, how to trigger
                the refresh
            system_app (Optional[SystemApp]): The system app

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("The dynamic options does not support async.")
class FunctionDynamicOptions(BaseDynamicOptions):
@@ -68,7 +128,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
..., description="The unique id of the function to generate the dynamic options"
)
def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]:
def refresh(
self,
request: Optional[RefreshOptionRequest] = None,
trigger: Literal["default", "http"] = "default",
system_app: Optional[SystemApp] = None,
) -> List[OptionValue]:
"""Refresh the dynamic options."""
if not request or not request.depends:
return self.func()
@@ -96,6 +161,109 @@ class FunctionDynamicOptions(BaseDynamicOptions):
return {"func_id": self.func_id}
class VariablesDynamicOptions(BaseDynamicOptions):
    """The variables dynamic options.

    Options are read from the ``VariablesProvider`` of the system app, and only
    when the refresh is not triggered by "default" and the request names both
    a variables key and a variables scope.
    """

    def support_async(
        self,
        system_app: Optional[SystemApp] = None,
        request: Optional[RefreshOptionRequest] = None,
    ) -> bool:
        """Whether the dynamic options support async.

        Returns ``False`` when the system app, the request or its variables
        key is missing, or when no ``BuiltinVariablesProvider`` component is
        registered under the variables key; otherwise the answer is delegated
        to that provider.
        """
        if not system_app or not request or not request.variables_key:
            return False

        from ...interface.variables import BuiltinVariablesProvider

        provider: BuiltinVariablesProvider = system_app.get_component(
            request.variables_key,
            component_type=BuiltinVariablesProvider,
            default_component=None,
        )
        if not provider:
            return False
        return provider.support_async()

    @staticmethod
    def _should_skip(
        request: Optional[RefreshOptionRequest],
        trigger: Literal["default", "http"],
    ) -> bool:
        """Tell whether a refresh has nothing to query and yields no options."""
        return bool(
            trigger == "default"
            or not request
            or not request.variables_key
            or not request.variables_scope
        )

    @staticmethod
    def _get_provider(system_app: Optional[SystemApp]):
        """Return the variables provider of the system app.

        Raises:
            ValueError: If no system app is given.
        """
        if not system_app:
            raise ValueError("The system app is required when refresh the variables.")

        from ...interface.variables import VariablesProvider

        return VariablesProvider.get_instance(system_app)

    @staticmethod
    def _query_kwargs(request: RefreshOptionRequest) -> Dict[str, Optional[str]]:
        """Build the keyword arguments of the variables query from the request."""
        return {
            "key": request.variables_key,
            "scope": request.variables_scope,
            "scope_key": request.variables_scope_key,
            "sys_code": request.variables_sys_code,
            "user_name": request.variables_user_name,
        }

    @staticmethod
    def _to_options(variables) -> List[OptionValue]:
        """Convert variables to option values (label, name, value)."""
        return [
            OptionValue(label=var.label, name=var.name, value=var.value)
            for var in variables
        ]

    def refresh(
        self,
        request: Optional[RefreshOptionRequest] = None,
        trigger: Literal["default", "http"] = "default",
        system_app: Optional[SystemApp] = None,
    ) -> List[OptionValue]:
        """Refresh the dynamic options.

        Args:
            request (Optional[RefreshOptionRequest]): The refresh request
            trigger (Literal["default", "http"]): The trigger type
            system_app (Optional[SystemApp]): The system app

        Returns:
            List[OptionValue]: One option per matching variable; empty when
                the trigger is "default" or the request lacks a variables key
                or scope.

        Raises:
            ValueError: If a query is needed but no system app is given.
        """
        if self._should_skip(request, trigger):
            # Only refresh when trigger is http and request is not None
            return []
        provider = self._get_provider(system_app)
        variables = provider.get_variables(**self._query_kwargs(request))
        return self._to_options(variables)

    async def async_refresh(
        self,
        request: Optional[RefreshOptionRequest] = None,
        trigger: Literal["default", "http"] = "default",
        system_app: Optional[SystemApp] = None,
    ) -> List[OptionValue]:
        """Refresh the dynamic options async.

        Same contract as ``refresh``, but the variables are read with the
        provider's ``async_get_variables``.
        """
        if self._should_skip(request, trigger):
            return []
        provider = self._get_provider(system_app)
        variables = await provider.async_get_variables(**self._query_kwargs(request))
        return self._to_options(variables)

    def to_dict(self) -> Dict:
        """Convert current metadata to json dict."""
        # NOTE(review): this class declares no ``key`` field here, so this
        # lookup presumably fails unless a base class provides it - confirm
        # the intended payload before relying on it.
        return {"key": self.key}
def _generate_unique_id(func: Callable) -> str:
if func.__name__ == "<lambda>":
func_id = f"lambda_{inspect.getfile(func)}_{inspect.getsourcelines(func)}"

View File

@@ -3,13 +3,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.i18n_utils import _
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ..awel.flow import Parameter, ResourceCategory, register_resource
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):

View File

@@ -0,0 +1,114 @@
import base64
import os
from cryptography.fernet import Fernet
from ..variables import (
FernetEncryption,
InMemoryStorage,
SimpleEncryption,
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
def test_fernet_encryption():
    """Fernet round-trips data, also across instances sharing the same key."""
    shared_key = Fernet.generate_key()
    plain_text, salt = "test_data", "test_salt"
    encryption = FernetEncryption(shared_key)
    new_encryption = FernetEncryption(shared_key)

    cipher_text = encryption.encrypt(plain_text, salt)
    assert cipher_text != plain_text

    round_trip = encryption.decrypt(cipher_text, salt)
    assert round_trip == plain_text
    assert round_trip == new_encryption.decrypt(cipher_text, salt)
def test_simple_encryption():
    """Simple encryption changes the data and decrypts back to the original."""
    random_key = base64.b64encode(os.urandom(32)).decode()
    plain_text, salt = "test_data", "test_salt"
    encryption = SimpleEncryption(random_key)

    cipher_text = encryption.encrypt(plain_text, salt)
    assert cipher_text != plain_text

    round_trip = encryption.decrypt(cipher_text, salt)
    assert round_trip == plain_text
def test_storage_variables_provider():
    """A saved secret variable is returned in plain form by ``get``."""
    provider = StorageVariablesProvider(InMemoryStorage(), SimpleEncryption())

    full_key = "key@name@global"
    value = "secret_value"
    identifier = VariablesIdentifier.from_str_identifier(full_key)
    variable = StorageVariables.from_identifier(
        identifier, value, "str", "test_label", category="secret"
    )
    provider.save(variable)

    assert provider.get(full_key) == value
def test_variables_identifier():
    """A six-part identifier parses into its fields and serializes back."""
    full_key = "key@name@global@scope_key@sys_code@user_name"
    identifier = VariablesIdentifier.from_str_identifier(full_key)

    expected_fields = [
        ("key", "key"),
        ("name", "name"),
        ("scope", "global"),
        ("scope_key", "scope_key"),
        ("sys_code", "sys_code"),
        ("user_name", "user_name"),
    ]
    for field_name, expected in expected_fields:
        assert getattr(identifier, field_name) == expected

    assert identifier.str_identifier == full_key
def test_storage_variables():
    """Constructor arguments are exposed as attributes and by ``to_dict``."""
    # Insertion order matches the order in which the fields are checked.
    expected = {
        "key": "test_key",
        "name": "test_name",
        "label": "test_label",
        "value": "test_value",
        "value_type": "str",
        "category": "common",
        "scope": "global",
    }
    storage_variable = StorageVariables(**expected)

    for field_name, field_value in expected.items():
        assert getattr(storage_variable, field_name) == field_value

    dict_representation = storage_variable.to_dict()
    for field_name, field_value in expected.items():
        assert dict_representation[field_name] == field_value

View File

@@ -0,0 +1,678 @@
"""Variables Module."""
import base64
import dataclasses
import hashlib
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.executor_utils import (
DefaultExecutorFactory,
blocking_func_to_async,
blocking_func_to_async_no_executor,
)
from .storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE"
BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows"
BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes"
BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces"
class Encryption(ABC):
    """Abstract interface for salt-based string encryption."""

    # Identifier of the implementation; concrete classes override it.
    name: str = "__abstract__"

    @abstractmethod
    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt ``data`` using ``salt`` and return the encrypted string."""

    @abstractmethod
    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt ``encrypted_data`` using ``salt`` and return the plain string."""
def _generate_key_from_password(
    password: bytes, salt: Optional[Union[str, bytes]] = None
) -> Tuple[bytes, bytes]:
    """Derive a urlsafe-base64 encoded key from a password.

    The key is derived with PBKDF2-HMAC-SHA256 (32 bytes, 100000 iterations).

    Args:
        password (bytes): The password to derive the key from.
        salt (Optional[Union[str, bytes]]): The salt. A string is encoded to
            bytes; when omitted, 16 random bytes are generated.

    Returns:
        Tuple[bytes, bytes]: The encoded key and the salt bytes actually used.
    """
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

    if salt is None:
        salt_bytes = os.urandom(16)
    elif isinstance(salt, str):
        salt_bytes = salt.encode()
    else:
        salt_bytes = salt

    derived = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt_bytes,
        iterations=100000,
    ).derive(password)
    return base64.urlsafe_b64encode(derived), salt_bytes
class FernetEncryption(Encryption):
    """Fernet encryption.

    Symmetric encryption backed by the cryptography library: the same key is
    used to encrypt and to decrypt. The Fernet key used for each operation is
    derived from the instance key and the salt passed to that call.
    """

    name = "fernet"

    def __init__(self, key: Optional[bytes] = None):
        """Initialize the fernet encryption.

        Args:
            key (Optional[bytes]): The key. A string is encoded to bytes; when
                omitted, a new Fernet key is generated.

        Raises:
            ImportError: If the cryptography package is not installed.
        """
        if isinstance(key, str):
            key = key.encode()
        try:
            from cryptography.fernet import Fernet
        except ImportError:
            raise ImportError(
                "cryptography is required for encryption, please install by running "
                "`pip install cryptography`"
            )
        self.key = Fernet.generate_key() if key is None else key

    def _fernet_for(self, salt: str):
        """Build the Fernet instance whose key is derived from ``salt``."""
        from cryptography.fernet import Fernet

        derived_key, _ = _generate_key_from_password(self.key, salt)
        return Fernet(derived_key)

    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt the data with the salt.

        Args:
            data (str): The data to encrypt.
            salt (str): The salt to use, which is used to derive the key.

        Returns:
            str: The encrypted data.
        """
        token = self._fernet_for(salt).encrypt(data.encode())
        return token.decode()

    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt the data with the salt.

        Args:
            encrypted_data (str): The encrypted data.
            salt (str): The salt to use, which is used to derive the key.

        Returns:
            str: The decrypted data.
        """
        plain = self._fernet_for(salt).decrypt(encrypted_data.encode())
        return plain.decode()
class SimpleEncryption(Encryption):
    """Simple implementation of encryption.

    A simple encryption algorithm that XORs the data with a key derived from
    the instance key and the salt (PBKDF2-HMAC-SHA256, 100000 iterations).
    """

    name = "simple"

    def __init__(self, key: Optional[str] = None):
        """Initialize the simple encryption.

        Args:
            key (Optional[str]): The key. When omitted, 32 random bytes are
                generated and stored base64 encoded.
        """
        if key is None:
            key = base64.b64encode(os.urandom(32)).decode()
        self.key = key

    def _derive_key(self, salt: str) -> bytes:
        """Derive the XOR key bytes from the instance key and the salt."""
        return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000)

    def _xor_with_key(self, payload: bytes, salt: str) -> bytes:
        """XOR every byte of ``payload`` with the repeated derived key."""
        key = self._derive_key(salt)
        # Size the keystream from the BYTE length of the payload: multi-byte
        # characters make the encoded data longer than the string, and ``zip``
        # would silently drop the bytes beyond a too-short keystream.
        keystream = key * (len(payload) // len(key) + 1)
        return bytes(x ^ y for x, y in zip(payload, keystream))

    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt the data with the salt.

        Args:
            data (str): The data to encrypt.
            salt (str): The salt used to derive the XOR key.

        Returns:
            str: The base64 encoded encrypted data.
        """
        encrypted = self._xor_with_key(data.encode(), salt)
        return base64.b64encode(encrypted).decode()

    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt the data with the salt.

        Args:
            encrypted_data (str): The base64 encoded encrypted data.
            salt (str): The salt used to derive the XOR key.

        Returns:
            str: The decrypted data.
        """
        decrypted = self._xor_with_key(base64.b64decode(encrypted_data), salt)
        return decrypted.decode()
@dataclasses.dataclass
class VariablesIdentifier(ResourceIdentifier):
    """The variables identifier.

    The string form joins key, name, scope, scope_key, sys_code and user_name
    with ``identifier_split``; unset optional parts become empty strings.
    """

    # Separator of the string identifier; not a constructor argument.
    identifier_split: str = dataclasses.field(default="@", init=False)
    key: str
    name: str
    scope: str = "global"
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None

    def _parts(self) -> List[Optional[str]]:
        """Return the identifier fields in string-identifier order."""
        return [
            self.key,
            self.name,
            self.scope,
            self.scope_key,
            self.sys_code,
            self.user_name,
        ]

    def __post_init__(self):
        """Validate the identifier.

        Raises:
            ValueError: If key, name or scope is empty, or if any field
                contains the separator.
        """
        if not (self.key and self.name and self.scope):
            raise ValueError("Key, name, and scope are required.")
        for part in self._parts():
            if part is not None and self.identifier_split in part:
                raise ValueError(
                    f"identifier_split {self.identifier_split} is not allowed in "
                    f"key, name, scope, scope_key, sys_code, user_name."
                )

    @property
    def str_identifier(self) -> str:
        """Return the string identifier of the identifier."""
        return self.identifier_split.join(part or "" for part in self._parts())

    def to_dict(self) -> Dict:
        """Convert the identifier to a dict.

        Returns:
            Dict: The dict of the identifier.
        """
        field_names = ("key", "name", "scope", "scope_key", "sys_code", "user_name")
        return dict(zip(field_names, self._parts()))

    @classmethod
    def from_str_identifier(
        cls, str_identifier: str, identifier_split: str = "@"
    ) -> "VariablesIdentifier":
        """Create a VariablesIdentifier from a string identifier.

        Args:
            str_identifier (str): The string identifier.
            identifier_split (str): The identifier split.

        Returns:
            VariablesIdentifier: The VariablesIdentifier.

        Raises:
            ValueError: If the identifier has fewer than three parts.
        """
        keys = str_identifier.split(identifier_split)
        if not keys:
            raise ValueError("Invalid string identifier.")
        if len(keys) < 2:
            raise ValueError("Invalid string identifier, must have name")
        if len(keys) < 3:
            raise ValueError("Invalid string identifier, must have scope")
        # Pad the missing optional parts with None; parts beyond the sixth
        # are ignored.
        padded = keys + [None] * (6 - len(keys))
        return cls(
            key=padded[0],
            name=padded[1],
            scope=padded[2],
            scope_key=padded[3],
            sys_code=padded[4],
            user_name=padded[5],
        )
@dataclasses.dataclass
class StorageVariables(StorageItem):
    """The storage variables.

    A single stored variable: its identifying parts (key, name, scope,
    scope_key, sys_code, user_name) plus label, value, value type, category,
    encryption metadata and the enabled flag.
    """

    key: str
    name: str
    label: str
    value: Any
    category: Literal["common", "secret"] = "common"
    scope: str = "global"
    value_type: Optional[str] = None
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None
    encryption_method: Optional[str] = None
    salt: Optional[str] = None
    enabled: int = 1

    _identifier: VariablesIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the identifier and infer ``value_type`` when it is not given."""
        self._identifier = VariablesIdentifier(
            key=self.key,
            name=self.name,
            scope=self.scope,
            scope_key=self.scope_key,
            sys_code=self.sys_code,
            user_name=self.user_name,
        )
        if not self.value_type:
            # Fall back to the Python type name of the value, e.g. "str".
            self.value_type = type(self.value).__name__

    @property
    def identifier(self) -> ResourceIdentifier:
        """Return the identifier."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge with another storage variables.

        Raises:
            ValueError: If ``other`` is not a ``StorageVariables``.
        """
        if not isinstance(other, StorageVariables):
            raise ValueError(f"Cannot merge with {type(other)}")
        self.from_object(other)

    def to_dict(self) -> Dict:
        """Convert the storage variables to a dict.

        Returns:
            Dict: The dict of the storage variables.
        """
        return {
            **self._identifier.to_dict(),
            "label": self.label,
            "value": self.value,
            "value_type": self.value_type,
            "category": self.category,
            "encryption_method": self.encryption_method,
            "salt": self.salt,
            # Keep the enabled flag so it survives a dict round trip.
            "enabled": self.enabled,
        }

    def from_object(self, other: "StorageVariables") -> None:
        """Copy the values from another storage variables object.

        ``key`` and ``name`` are left untouched.
        """
        # NOTE(review): ``_identifier`` is not rebuilt here, so it keeps the
        # scope parts captured in ``__post_init__`` - confirm this is intended.
        self.label = other.label
        self.value = other.value
        self.value_type = other.value_type
        self.category = other.category
        self.scope = other.scope
        self.scope_key = other.scope_key
        self.sys_code = other.sys_code
        self.user_name = other.user_name
        self.encryption_method = other.encryption_method
        self.salt = other.salt
        self.enabled = other.enabled

    @classmethod
    def from_identifier(
        cls,
        identifier: VariablesIdentifier,
        value: Any,
        value_type: str,
        label: str = "",
        category: Literal["common", "secret"] = "common",
        encryption_method: Optional[str] = None,
        salt: Optional[str] = None,
    ) -> "StorageVariables":
        """Create storage variables from an identifier and a value."""
        return cls(
            key=identifier.key,
            name=identifier.name,
            label=label,
            value=value,
            value_type=value_type,
            category=category,
            scope=identifier.scope,
            scope_key=identifier.scope_key,
            sys_code=identifier.sys_code,
            user_name=identifier.user_name,
            encryption_method=encryption_method,
            salt=salt,
        )
class VariablesProvider(BaseComponent, ABC):
    """Interface every variables provider has to implement."""

    name = ComponentType.VARIABLES_PROVIDER.value

    @abstractmethod
    def get(
        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
    ) -> Any:
        """Look up one variable value by its full string key.

        Args:
            full_key (str): The full string identifier of the variable.
            default_value (Optional[str]): The fallback value for a missing
                variable.
        """

    @abstractmethod
    def save(self, variables_item: StorageVariables) -> None:
        """Persist one variable.

        Args:
            variables_item (StorageVariables): The variable to store.
        """

    @abstractmethod
    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """List the variables stored under ``key`` for the given scope parts."""

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Async counterpart of ``get_variables``.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError("Current variables provider does not support async.")

    def support_async(self) -> bool:
        """Tell whether ``async_get_variables`` is usable; ``False`` here."""
        return False
class VariablesPlaceHolder:
    """The variables place holder.

    Remembers which variable (``full_key``) backs a parameter and the type the
    resolved value should be converted to.
    """

    def __init__(
        self,
        param_name: str,
        full_key: str,
        value_type: str,
        default_value: Any = _EMPTY_DEFAULT_VALUE,
    ):
        """Initialize the variables place holder.

        Args:
            param_name (str): The name of the parameter to fill.
            full_key (str): The full key of the variable to resolve.
            value_type (str): The target type name ("str", "int", "float",
                "bool"); any other name leaves the value untouched.
            default_value (Any): The default handed to the provider lookup.
        """
        self.param_name = param_name
        self.full_key = full_key
        self.value_type = value_type
        self.default_value = default_value

    def parse(self, variables_provider: VariablesProvider) -> Any:
        """Resolve the variable through the provider and cast it.

        Falsy results (``None``, ``""``, ``0``, ``False``, ...) are returned
        as-is without casting.
        """
        value = variables_provider.get(self.full_key, self.default_value)
        if not value:
            return value
        return self._cast_to_type(value)

    def _cast_to_type(self, value: Any) -> Any:
        """Convert ``value`` according to ``self.value_type``."""
        if self.value_type == "str":
            return str(value)
        if self.value_type == "int":
            return int(value)
        if self.value_type == "float":
            return float(value)
        if self.value_type == "bool":
            # Only text needs parsing; values that are already bool/int (for
            # example after JSON decoding) have no ``lower`` method and are
            # converted by plain truthiness.
            if isinstance(value, str):
                lowered = value.lower()
                if lowered in ["true", "1"]:
                    return True
                if lowered in ["false", "0"]:
                    return False
            return bool(value)
        return value

    def __repr__(self):
        """Return the representation of the variables place holder."""
        return (
            f"<VariablesPlaceHolder "
            f"{self.param_name} {self.full_key} {self.value_type}>"
        )
class StorageVariablesProvider(VariablesProvider):
    """Variables provider backed by a ``StorageInterface``.

    Values are stored JSON-wrapped; values of category "secret" are encrypted
    with the configured ``Encryption`` before being stored.
    """

    def __init__(
        self,
        storage: Optional[StorageInterface] = None,
        encryption: Optional[Encryption] = None,
        system_app: Optional[SystemApp] = None,
        key: Optional[str] = None,
    ):
        """Create the provider.

        Args:
            storage (Optional[StorageInterface]): The backing storage; an
                ``InMemoryStorage`` is created when omitted.
            encryption (Optional[Encryption]): The encryption for secrets; a
                ``SimpleEncryption`` built from ``key`` is used when omitted.
            system_app (Optional[SystemApp]): The system app.
            key (Optional[str]): The key for the fallback ``SimpleEncryption``.
        """
        backing_storage = InMemoryStorage() if storage is None else storage
        self.system_app = system_app
        if not encryption:
            encryption = SimpleEncryption(key)
        self.encryption = encryption
        self.storage = backing_storage
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Remember the system app this provider is attached to."""
        self.system_app = system_app

    def get(
        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
    ) -> Any:
        """Load one variable value, decrypting it when it is a stored secret.

        Raises:
            ValueError: If the variable is missing and no default was given.
        """
        identifier = VariablesIdentifier.from_str_identifier(full_key)
        found: Optional[StorageVariables] = self.storage.load(
            identifier, StorageVariables
        )
        if found is None:
            if default_value == _EMPTY_DEFAULT_VALUE:
                raise ValueError(f"Variable {full_key} not found")
            return default_value

        found.value = self.deserialize_value(found.value)
        # Decrypt only a non-empty secret that carries its encryption metadata.
        needs_decrypt = (
            found.value is not None
            and found.category == "secret"
            and found.encryption_method
            and found.salt
        )
        if needs_decrypt:
            found.value = self.encryption.decrypt(found.value, found.salt)
        return found.value

    def save(self, variables_item: StorageVariables) -> None:
        """Store a variable, encrypting it first when its category is "secret".

        The passed item is modified in place: its value is replaced by the
        stored (JSON-wrapped, possibly encrypted) form, and for secrets the
        encryption method and salt are filled in.
        """
        if variables_item.category == "secret":
            # A fresh random salt per saved secret.
            fresh_salt = base64.b64encode(os.urandom(16)).decode()
            plain_text = str(variables_item.value)
            variables_item.value = self.encryption.encrypt(plain_text, fresh_salt)
            variables_item.encryption_method = self.encryption.name
            variables_item.salt = fresh_salt
        # Wrap the value so that it is JSON serializable.
        variables_item.value = self.serialize_value(variables_item.value)
        self.storage.save_or_update(variables_item)

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """List enabled variables for ``key``, preferring a builtin provider.

        Values read from storage are JSON-unwrapped only; no decryption is
        applied here.
        """
        handled, builtin_items = self._get_builtins_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        if handled:
            return builtin_items

        conditions = {
            "key": key,
            "scope": scope,
            "scope_key": scope_key,
            "sys_code": sys_code,
            "user_name": user_name,
            "enabled": 1,
        }
        stored_items = self.storage.query(
            QuerySpec(conditions=conditions), StorageVariables
        )
        for item in stored_items:
            item.value = self.deserialize_value(item.value)
        return stored_items

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Async variant of ``get_variables``.

        An async-capable builtin provider is awaited directly; otherwise the
        sync ``get_variables`` is run through the blocking-to-async helpers.
        """
        handled, builtin_items = await self._async_get_builtins_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        if handled:
            return builtin_items

        executor_factory: Optional[
            DefaultExecutorFactory
        ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
        if not executor_factory:
            return await blocking_func_to_async_no_executor(
                self.get_variables, key, scope, scope_key, sys_code, user_name
            )
        return await blocking_func_to_async(
            executor_factory.create(),
            self.get_variables,
            key,
            scope,
            scope_key,
            sys_code,
            user_name,
        )

    def _find_builtin_provider(self, key: str) -> Optional["BuiltinVariablesProvider"]:
        """Return the builtin provider registered under ``key``, if any."""
        if self.system_app is None:
            return None
        return self.system_app.get_component(
            key,
            component_type=BuiltinVariablesProvider,
            default_component=None,
        )

    def _get_builtins_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> Tuple[bool, List[StorageVariables]]:
        """Ask a builtin provider for ``key``.

        Returns:
            Tuple[bool, List[StorageVariables]]: Whether a builtin provider
                handled the key, and the variables it returned.
        """
        builtin = self._find_builtin_provider(key)
        if not builtin:
            return False, []
        builtin_items = builtin.get_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        return True, builtin_items

    async def _async_get_builtins_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> Tuple[bool, List[StorageVariables]]:
        """Ask an async-capable builtin provider for ``key``.

        A provider that does not support async is reported as "not handled".
        """
        builtin = self._find_builtin_provider(key)
        if not builtin or not builtin.support_async():
            return False, []
        builtin_items = await builtin.async_get_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        return True, builtin_items

    @classmethod
    def serialize_value(cls, value: Any) -> str:
        """Wrap ``value`` as ``{"value": ...}`` and dump it to a JSON string."""
        return json.dumps({"value": value}, ensure_ascii=False)

    @classmethod
    def deserialize_value(cls, value: str) -> Any:
        """Parse a JSON string made by ``serialize_value`` and unwrap it."""
        return json.loads(value)["value"]
class BuiltinVariablesProvider(VariablesProvider, ABC):
    """Base class for providers of builtin variables.

    Subclass it to expose builtin variables such as LLMs, agents, datasources,
    knowledge bases, etc. Single-value ``get`` and ``save`` are not supported.
    """

    name = "dbgpt_variables_builtin"

    def __init__(self, system_app: Optional[SystemApp] = None):
        """Create the provider, optionally bound to a system app."""
        self.system_app = system_app
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Remember the system app this provider is attached to."""
        self.system_app = system_app

    def get(
        self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
    ) -> Any:
        """Unsupported for builtin providers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("BuiltinVariablesProvider does not support get.")

    def save(self, variables_item: StorageVariables) -> None:
        """Unsupported for builtin providers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("BuiltinVariablesProvider does not support save.")

View File

@@ -1,7 +1,11 @@
from typing import Any
from dbgpt.serve.core.config import BaseServeConfig
from dbgpt.serve.core.schemas import Result, add_exception_handler
from dbgpt.serve.core.serve import BaseServe
from dbgpt.serve.core.service import BaseService
from dbgpt.util.executor_utils import BlockingFunction, DefaultExecutorFactory
from dbgpt.util.executor_utils import blocking_func_to_async as _blocking_func_to_async
__ALL__ = [
"Result",
@@ -10,3 +14,11 @@ __ALL__ = [
"BaseService",
"BaseServe",
]
async def blocking_func_to_async(
    system_app, func: BlockingFunction, *args, **kwargs
) -> Any:
    """Await a potentially blocking function by running it in an executor.

    The executor is created from the ``DefaultExecutorFactory`` instance looked
    up through ``system_app``; ``args`` and ``kwargs`` are forwarded to
    ``func``.
    """
    executor_factory = DefaultExecutorFactory.get_instance(system_app)
    pool = executor_factory.create()
    return await _blocking_func_to_async(pool, func, *args, **kwargs)

View File

@@ -7,12 +7,19 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.serve.core import Result
from dbgpt.serve.core import Result, blocking_func_to_async
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse
from ..service.variables_service import VariablesService
from .schemas import (
RefreshNodeRequest,
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
# Router on which the endpoints of this module are registered.
router = APIRouter()
@@ -23,7 +30,12 @@ global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
    """Get the flow service instance registered on the global system app."""
    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
def get_variable_service() -> VariablesService:
    """Return the ``VariablesService`` instance of the global system app."""
    variables_service = VariablesService.get_instance(global_system_app)
    return variables_service
# HTTP bearer credential extractor, created with auto_error disabled.
get_bearer_token = HTTPBearer(auto_error=False)
@@ -261,16 +273,80 @@ async def refresh_nodes(refresh_request: RefreshNodeRequest):
"""
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
new_metadata = _OPERATOR_REGISTRY.refresh(
key=refresh_request.id,
is_operator=refresh_request.flow_type == "operator",
request=refresh_request.refresh,
# Make sure the variables provider is initialized
_ = get_variable_service().variables_provider
new_metadata = await _OPERATOR_REGISTRY.refresh(
refresh_request.id,
refresh_request.flow_type == "operator",
refresh_request.refresh,
"http",
global_system_app,
)
return Result.succ(new_metadata)
@router.post(
    "/variables",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def create_variables(
    variables_request: VariablesRequest,
) -> Result[VariablesResponse]:
    """Create a new variable.

    The service's ``create`` is run through ``blocking_func_to_async`` rather
    than called directly in the coroutine.

    Args:
        variables_request (VariablesRequest): The variable to create.

    Returns:
        Result[VariablesResponse]: The created variable wrapped in a success
            result.
    """
    variables_service = get_variable_service()
    created = await blocking_func_to_async(
        global_system_app, variables_service.create, variables_request
    )
    return Result.succ(created)
@router.put(
    "/variables/{v_id}",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def update_variables(
    v_id: int, variables_request: VariablesRequest
) -> Result[VariablesResponse]:
    """Update an existing variable.

    The service's ``update`` is run through ``blocking_func_to_async`` rather
    than called directly in the coroutine.

    Args:
        v_id (int): The id of the variable to update.
        variables_request (VariablesRequest): The new content of the variable.

    Returns:
        Result[VariablesResponse]: The updated variable wrapped in a success
            result.
    """
    variables_service = get_variable_service()
    updated = await blocking_func_to_async(
        global_system_app, variables_service.update, v_id, variables_request
    )
    return Result.succ(updated)
def init_endpoints(system_app: SystemApp) -> None:
    """Register the services and builtin variables providers of this module.

    After registration, ``system_app`` is stored in ``global_system_app``.
    """
    global global_system_app

    # Imported inside the function; ``variables_provider`` itself imports from
    # this endpoints module.
    from .variables_provider import (
        BuiltinAllSecretVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
        BuiltinFlowVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinNodeVariablesProvider,
    )

    # Registration order is kept: services first, then the builtin providers.
    component_classes = (
        Service,
        VariablesService,
        BuiltinFlowVariablesProvider,
        BuiltinNodeVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinAllSecretVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
    )
    for component_cls in component_classes:
        system_app.register(component_cls)
    global_system_app = system_app

View File

@@ -1,4 +1,4 @@
from typing import List, Literal
from typing import Any, List, Literal, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core.awel.flow.flow_factory import FlowPanel
@@ -17,6 +17,69 @@ class ServerResponse(FlowPanel):
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
class VariablesRequest(BaseModel):
    """Request model for a variable.

    Used to create a new variable in DB-GPT.
    """

    # Identification of the variable.
    key: str = Field(
        default=...,
        description="The key of the variable to create",
        examples=["dbgpt.model.openai.api_key"],
    )
    name: str = Field(
        default=...,
        description="The name of the variable to create",
        examples=["my_first_openai_key"],
    )
    label: str = Field(
        default=...,
        description="The label of the variable to create",
        examples=["My First OpenAI Key"],
    )

    # Value and how to interpret it.
    value: Any = Field(
        default=...,
        description="The value of the variable to create",
        examples=["1234567890"],
    )
    value_type: Literal["str", "int", "float", "bool"] = Field(
        default="str",
        description="The type of the value of the variable to create",
        examples=["str", "int", "float", "bool"],
    )
    category: Literal["common", "secret"] = Field(
        default=...,
        description="The category of the variable to create",
        examples=["common"],
    )

    # Scope of the variable; scope_key may be null but must be supplied.
    scope: str = Field(
        default=...,
        description="The scope of the variable to create",
        examples=["global"],
    )
    scope_key: Optional[str] = Field(
        default=...,
        description="The scope key of the variable to create",
        examples=["dbgpt"],
    )

    enabled: Optional[bool] = Field(
        default=True,
        description="Whether the variable is enabled",
        examples=[True],
    )
    user_name: Optional[str] = Field(default=None, description="User name")
    sys_code: Optional[str] = Field(default=None, description="System code")
class VariablesResponse(VariablesRequest):
    """Response model for a variable: the request fields plus the record id."""

    id: int = Field(default=..., description="The id of the variable", examples=[1])
class RefreshNodeRequest(BaseModel):
"""Flow response model"""

View File

@@ -0,0 +1,260 @@
from typing import List, Literal, Optional
from dbgpt.core.interface.variables import (
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
BUILTIN_VARIABLES_CORE_FLOW_NODES,
BUILTIN_VARIABLES_CORE_FLOWS,
BUILTIN_VARIABLES_CORE_LLMS,
BUILTIN_VARIABLES_CORE_SECRETS,
BUILTIN_VARIABLES_CORE_VARIABLES,
BuiltinVariablesProvider,
StorageVariables,
)
from ..service.service import Service
from .endpoints import get_service, get_variable_service
from .schemas import ServerResponse
class BuiltinFlowVariablesProvider(BuiltinVariablesProvider):
    """Builtin provider that exposes flows as variables.

    Registered under ``BUILTIN_VARIABLES_CORE_FLOWS``.
    """

    name = BUILTIN_VARIABLES_CORE_FLOWS

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Return one variable per flow visible to ``user_name``/``sys_code``.

        Each variable is named and labelled after the flow and carries the
        flow uid as its value. Only a single page query is issued.
        """
        flow_service: Service = get_service()
        filters = {
            "user_name": user_name,
            "sys_code": sys_code,
        }
        # presumably (page, page_size) - only the first 1000 flows are listed.
        first_page = flow_service.get_list_by_page(filters, 1, 1000)
        flows: List[ServerResponse] = first_page.items
        return [
            StorageVariables(
                key=key,
                name=flow.name,
                label=flow.label,
                value=flow.uid,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for flow in flows
        ]
class BuiltinNodeVariablesProvider(BuiltinVariablesProvider):
    """Builtin provider that exposes registered flow nodes as variables.

    Registered under ``BUILTIN_VARIABLES_CORE_FLOW_NODES``.
    """

    name = BUILTIN_VARIABLES_CORE_FLOW_NODES

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Return one variable per entry of the operator registry metadata.

        The entry's "name" and "label" become the variable name and label and
        its "id" becomes the value.
        """
        from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

        return [
            StorageVariables(
                key=key,
                name=node_meta["name"],
                label=node_meta["label"],
                value=node_meta["id"],
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for node_meta in _OPERATOR_REGISTRY.metadata_list()
        ]
class BuiltinAllVariablesProvider(BuiltinVariablesProvider):
    """Builtin provider that exposes the stored variables themselves.

    Registered under ``BUILTIN_VARIABLES_CORE_VARIABLES``.
    """

    name = BUILTIN_VARIABLES_CORE_VARIABLES

    def _get_variables_from_db(
        self,
        key: str,
        scope: str,
        scope_key: Optional[str],
        sys_code: Optional[str],
        user_name: Optional[str],
        category: Literal["common", "secret"] = "common",
    ) -> List[StorageVariables]:
        """Re-key the variables of one category under the requested key/scope.

        Name, label and value come from what the variables service lists for
        ``category``; key and scope parts come from the arguments.
        """
        listed = get_variable_service().list_all_variables(category)
        return [
            StorageVariables(
                key=key,
                name=stored.name,
                label=stored.label,
                value=stored.value,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for stored in listed
        ]

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Return the variables of the default ("common") category.

        TODO: Return all builtin variables
        """
        return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name)
class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider):
    """Builtin provider that exposes the stored secret variables.

    Registered under ``BUILTIN_VARIABLES_CORE_SECRETS``.
    """

    name = BUILTIN_VARIABLES_CORE_SECRETS

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Return the variables of the "secret" category."""
        secret_category = "secret"
        return self._get_variables_from_db(
            key, scope, scope_key, sys_code, user_name, secret_category
        )
class BuiltinLLMVariablesProvider(BuiltinVariablesProvider):
    """Builtin LLM variables provider.

    Exposes the healthy model instances of worker type "llm" as variables.
    Registered under ``BUILTIN_VARIABLES_CORE_LLMS``. Only the async API is
    implemented.
    """

    name = BUILTIN_VARIABLES_CORE_LLMS

    def support_async(self) -> bool:
        """Whether the dynamic options support async."""
        return True

    async def _get_models(
        self,
        key: str,
        scope: str,
        scope_key: Optional[str],
        sys_code: Optional[str],
        user_name: Optional[str],
        expect_worker_type: str = "llm",
    ) -> List[StorageVariables]:
        """List healthy model instances of one worker type as variables.

        Instance names are read as ``<worker_name>@<worker_type>``. Each
        distinct worker name yields one variable whose name, label and value
        are the worker name.
        """
        from dbgpt.model.cluster.controller.controller import BaseModelController

        controller = BaseModelController.get_instance(self.system_app)
        models = await controller.get_all_instances(healthy_only=True)

        worker_names: List[str] = []
        seen = set()
        for model in models:
            # Split on the last separator only, so a name that has no "@" or
            # more than one "@" no longer aborts the whole listing.
            worker_name, separator, worker_type = model.model_name.rpartition("@")
            if not separator or not worker_name:
                continue
            if worker_type == expect_worker_type and worker_name not in seen:
                seen.add(worker_name)
                worker_names.append(worker_name)

        return [
            StorageVariables(
                key=key,
                name=worker_name,
                label=worker_name,
                value=worker_name,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for worker_name in worker_names
        ]

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables."""
        return await self._get_models(key, scope, scope_key, sys_code, user_name)

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Sync access is not implemented for this provider.

        Raises:
            NotImplementedError: Always; use ``async_get_variables``.
        """
        raise NotImplementedError(
            "Not implemented get variables sync, please use async_get_variables"
        )
class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider):
    """Builtin provider that exposes embedding models as variables.

    Registered under ``BUILTIN_VARIABLES_CORE_EMBEDDINGS``. Reuses the model
    listing of the LLM provider with worker type "text2vec".
    """

    name = BUILTIN_VARIABLES_CORE_EMBEDDINGS

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Return the healthy model instances of worker type "text2vec"."""
        embedding_worker_type = "text2vec"
        return await self._get_models(
            key, scope, scope_key, sys_code, user_name, embedding_worker_type
        )

View File

@@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow"
# Variant of the serve app name with a capitalized "Flow" suffix.
SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow"
# Prefix of the configuration keys that belong to this serve app.
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow."
# Component names, both derived from SERVE_APP_NAME.
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service"
# Database table names
SERVER_APP_TABLE_NAME = "dbgpt_serve_flow"
SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables"
@dataclass
@@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig):
load_dbgpts_interval: int = field(
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
)
encrypt_key: Optional[str] = field(
default=None, metadata={"help": "The key to encrypt the data"}
)

View File

@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import model_to_dict
from dbgpt.core.awel.flow.flow_factory import State
from dbgpt.core.interface.variables import StorageVariablesProvider
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
from ..api.schemas import (
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
class ServeEntity(Model):
@@ -74,6 +80,56 @@ class ServeEntity(Model):
return editable is None or editable == 0
class VariablesEntity(Model):
    """Database model of a stored variable."""

    __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME

    id = Column(Integer, primary_key=True, comment="Auto increment id")
    key = Column(String(128), index=True, nullable=False, comment="Variable key")
    name = Column(String(128), index=True, nullable=True, comment="Variable name")
    label = Column(String(128), nullable=True, comment="Variable label")
    value = Column(Text, nullable=True, comment="Variable value, JSON format")
    value_type = Column(
        String(32),
        nullable=True,
        comment="Variable value type(string, int, float, bool)",
    )
    category = Column(
        String(32),
        default="common",
        nullable=True,
        comment="Variable category(common or secret)",
    )
    encryption_method = Column(
        String(32),
        nullable=True,
        comment="Variable encryption method(fernet, simple, rsa, aes)",
    )
    salt = Column(String(128), nullable=True, comment="Variable salt")
    scope = Column(
        String(32),
        default="global",
        nullable=True,
        comment="Variable scope(global,flow,app,agent,datasource,flow:uid,"
        "flow:dag_name,agent:agent_name) etc",
    )
    scope_key = Column(
        String(256),
        nullable=True,
        comment="Variable scope key, default is empty, for scope is 'flow:uid', "
        "the scope_key is uid of flow",
    )
    enabled = Column(
        Integer,
        default=1,
        nullable=True,
        comment="Variable enabled, 0: disabled, 1: enabled",
    )
    user_name = Column(String(128), index=True, nullable=True, comment="User name")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    # ``onupdate`` refreshes the timestamp on every ORM-issued UPDATE; with only
    # ``default`` the column would keep the creation time forever.
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Flow"""
@@ -222,3 +278,108 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
session.merge(entry)
session.commit()
return self.get_one(query_request)
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
    """Data access object for variables."""

    def __init__(self, serve_config: ServeConfig):
        """Create the DAO and keep the serve configuration."""
        super().__init__()
        self._serve_config = serve_config

    def from_request(
        self, request: Union[VariablesRequest, Dict[str, Any]]
    ) -> VariablesEntity:
        """Build an entity from a request model or a plain dict.

        The value is stored JSON-wrapped and ``enabled`` is stored as 1/0
        (a missing ``enabled`` counts as enabled).

        Args:
            request (Union[VariablesRequest, Dict[str, Any]]): The request

        Returns:
            VariablesEntity: The entity
        """
        if isinstance(request, VariablesRequest):
            payload = model_to_dict(request)
        else:
            payload = request

        serialized_value = StorageVariablesProvider.serialize_value(
            payload.get("value")
        )
        enabled_flag = 1 if payload.get("enabled", True) else 0

        # Columns that are copied from the payload unchanged.
        copied_fields = (
            "key",
            "name",
            "label",
            "value_type",
            "category",
            "encryption_method",
            "salt",
            "scope",
            "scope_key",
            "user_name",
            "sys_code",
        )
        columns = {field: payload.get(field) for field in copied_fields}
        columns["value"] = serialized_value
        columns["enabled"] = enabled_flag
        return VariablesEntity(**columns)

    def _entity_fields(self, entity: VariablesEntity) -> Dict[str, Any]:
        """Collect the keyword arguments shared by request and response.

        The stored value is JSON-unwrapped; secret values are masked.
        """
        plain_value = StorageVariablesProvider.deserialize_value(entity.value)
        if entity.category == "secret":
            plain_value = "******"
        return {
            "key": entity.key,
            "name": entity.name,
            "label": entity.label,
            "value": plain_value,
            "value_type": entity.value_type,
            "category": entity.category,
            "encryption_method": entity.encryption_method,
            "salt": entity.salt,
            "scope": entity.scope,
            "scope_key": entity.scope_key,
            "enabled": entity.enabled == 1,
            "user_name": entity.user_name,
            "sys_code": entity.sys_code,
        }

    def to_request(self, entity: VariablesEntity) -> VariablesRequest:
        """Convert an entity back to a request model.

        Args:
            entity (VariablesEntity): The entity

        Returns:
            VariablesRequest: The request, with secret values masked
        """
        return VariablesRequest(**self._entity_fields(entity))

    def to_response(self, entity: VariablesEntity) -> VariablesResponse:
        """Convert an entity to a response model.

        Args:
            entity (VariablesEntity): The entity

        Returns:
            VariablesResponse: The response, with secret values masked
        """
        fields = self._entity_fields(entity)
        timestamp_format = "%Y-%m-%d %H:%M:%S"
        created_text = entity.gmt_created.strftime(timestamp_format)
        modified_text = entity.gmt_modified.strftime(timestamp_format)
        return VariablesResponse(
            id=entity.id,
            gmt_created=created_text,
            gmt_modified=modified_text,
            **fields,
        )

View File

@@ -0,0 +1,69 @@
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import StorageItemAdapter
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
from .models import VariablesEntity
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
    """Variables adapter.

    Convert between storage format and database model.
    """

    def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
        """Convert to storage format."""
        return VariablesEntity(
            key=item.key,
            name=item.name,
            label=item.label,
            value=item.value,
            value_type=item.value_type,
            category=item.category,
            encryption_method=item.encryption_method,
            salt=item.salt,
            scope=item.scope,
            scope_key=item.scope_key,
            sys_code=item.sys_code,
            user_name=item.user_name,
            # Carry the enabled flag so a disabled item is stored as disabled.
            enabled=item.enabled,
        )

    def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
        """Convert from storage format."""
        return StorageVariables(
            key=model.key,
            name=model.name,
            label=model.label,
            value=model.value,
            value_type=model.value_type,
            category=model.category,
            encryption_method=model.encryption_method,
            salt=model.salt,
            scope=model.scope,
            scope_key=model.scope_key,
            sys_code=model.sys_code,
            user_name=model.user_name,
            # The column is nullable; treat a missing flag as enabled, which is
            # the default of both the column and the storage item.
            enabled=model.enabled if model.enabled is not None else 1,
        )

    def get_query_for_identifier(
        self,
        storage_format: Type[VariablesEntity],
        resource_id: VariablesIdentifier,
        **kwargs,
    ):
        """Get query for identifier.

        Builds an equality filter for every identifier part that is not None
        and restricts the result to enabled rows.

        Raises:
            Exception: If no ``session`` keyword argument is supplied.
        """
        session: Session = kwargs.get("session")
        if session is None:
            raise Exception("session is None")
        query_obj = session.query(VariablesEntity)
        for key, value in resource_id.to_dict().items():
            # Unset identifier parts do not constrain the query.
            if value is None:
                continue
            query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)

        # enabled must be True
        query_obj = query_obj.filter(VariablesEntity.enabled == 1)
        return query_obj

View File

@@ -4,6 +4,7 @@ from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.core.interface.variables import VariablesProvider
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
@@ -40,6 +41,8 @@ class Serve(BaseServe):
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
self._variables_provider: Optional[VariablesProvider] = None
self._serve_config: Optional[ServeConfig] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
@@ -62,5 +65,37 @@ class Serve(BaseServe):
def before_start(self):
    """Called before the start of the application.

    Resolves the database manager, loads the serve configuration and builds
    the database-backed variables provider that is later exposed through
    the ``variables_provider`` property.
    """
    from dbgpt.core.interface.variables import (
        FernetEncryption,
        StorageVariablesProvider,
    )
    from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
    from dbgpt.util.serialization.json_serialization import JsonSerializer

    from .models.models import VariablesEntity
    from .models.variables_adapter import VariablesAdapter

    # Resolve the database manager once; the previous second, identical call
    # just overwrote the attribute with the result of the same lookup.
    self._db_manager = self.create_or_get_db_manager()
    self._serve_config = ServeConfig.from_app_config(
        self._system_app.config, SERVE_CONFIG_KEY_PREFIX
    )

    storage_adapter = VariablesAdapter()
    serializer = JsonSerializer()
    storage = SQLAlchemyStorage(
        self._db_manager,
        VariablesEntity,
        storage_adapter,
        serializer,
    )
    # The encryption key comes from the serve configuration loaded above.
    self._variables_provider = StorageVariablesProvider(
        storage=storage,
        encryption=FernetEncryption(self._serve_config.encrypt_key),
        system_app=self._system_app,
    )
@property
def variables_provider(self):
    """Return the database-backed variables provider of this serve app.

    The value is whatever ``before_start`` stored on the instance.
    """
    provider = self._variables_provider
    return provider

View File

@@ -0,0 +1,148 @@
from typing import List, Optional
from dbgpt import SystemApp
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
from dbgpt.serve.core import BaseService
from ..api.schemas import VariablesRequest, VariablesResponse
from ..config import (
SERVE_CONFIG_KEY_PREFIX,
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
ServeConfig,
)
from ..models.models import VariablesDao, VariablesEntity
class VariablesService(
BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
):
"""Variables service"""
name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
    """Create the variables service.

    Args:
        system_app (SystemApp): The system app passed on to the base service
        dao (Optional[VariablesDao]): DAO to use; when omitted, ``init_app``
            builds a ``VariablesDao`` from the serve config
    """
    # Set up the instance attributes before delegating to the base class.
    # NOTE(review): presumably the base constructor triggers ``init_app``,
    # which reads ``self._dao`` -- confirm against ``BaseService``.
    self._dao: VariablesDao = dao
    self._serve_config: ServeConfig = None
    self._system_app = None
    super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
    """Bind the service to the system app and load what it depends on.

    Args:
        system_app (SystemApp): The system app
    """
    super().init_app(system_app)
    serve_config = ServeConfig.from_app_config(
        system_app.config, SERVE_CONFIG_KEY_PREFIX
    )
    self._serve_config = serve_config
    # Keep a DAO injected through the constructor; otherwise create one
    # from the configuration that was just loaded.
    if not self._dao:
        self._dao = VariablesDao(serve_config)
    self._system_app = system_app
@property
def dao(self) -> VariablesDao:
    """Return the DAO this service queries stored variables through."""
    current_dao = self._dao
    return current_dao
@property
def variables_provider(self) -> VariablesProvider:
    """Resolve the variables provider used by this service.

    A provider already registered with the system app is preferred. When
    there is none, the provider held by the ``Serve`` component is
    registered with the system app and returned.

    Returns:
        VariablesProvider: The resolved provider
    """
    registered = VariablesProvider.get_instance(
        self._system_app, default_component=None
    )
    if registered:
        return registered

    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import with ``serve`` --
    # confirm.
    from ..serve import Serve

    fallback = Serve.get_instance(self._system_app).variables_provider
    self._system_app.register_instance(fallback)
    return fallback
@property
def config(self) -> ServeConfig:
    """Return the serve configuration loaded by ``init_app``."""
    serve_config = self._serve_config
    return serve_config
def create(self, request: VariablesRequest) -> VariablesResponse:
    """Save a new variable and return its stored record.

    Args:
        request (VariablesRequest): The variable to create

    Returns:
        VariablesResponse: The record read back through the DAO after the
            variable has been saved by the variables provider
    """
    # Request attributes forwarded unchanged to ``StorageVariables``.
    variable_fields = (
        "key",
        "name",
        "label",
        "value",
        "value_type",
        "category",
        "scope",
        "scope_key",
        "user_name",
        "sys_code",
    )
    new_variables = StorageVariables(
        **{field: getattr(request, field) for field in variable_fields}
    )
    self.variables_provider.save(new_variables)

    # Request attributes used to find the saved row again.
    lookup_fields = (
        "key",
        "name",
        "scope",
        "scope_key",
        "sys_code",
        "user_name",
        "enabled",
    )
    lookup = {field: getattr(request, field) for field in lookup_fields}
    return self.dao.get_one(lookup)
def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
    """Overwrite an existing variable and return its stored record.

    Args:
        _ (int): Positional argument that is not used by this method
        request (VariablesRequest): The new state of the variable

    Returns:
        VariablesResponse: The record read back through the DAO after the
            variable has been saved by the variables provider

    Raises:
        ValueError: If the variables provider returns ``None`` for the
            variable's identifier
    """
    # Request attributes forwarded unchanged to ``StorageVariables``.
    variable_fields = (
        "key",
        "name",
        "label",
        "value",
        "value_type",
        "category",
        "scope",
        "scope_key",
        "user_name",
        "sys_code",
    )
    updated = StorageVariables(
        **{field: getattr(request, field) for field in variable_fields}
    )

    # Refuse to update something the provider does not know yet.
    current = self.variables_provider.get(updated.identifier.str_identifier, None)
    if current is None:
        raise ValueError(f"Variable {updated.identifier.str_identifier} not found")
    self.variables_provider.save(updated)

    # Request attributes used to find the saved row again.
    lookup_fields = (
        "key",
        "name",
        "scope",
        "scope_key",
        "sys_code",
        "user_name",
        "enabled",
    )
    lookup = {field: getattr(request, field) for field in lookup_fields}
    return self.dao.get_one(lookup)
def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
    """Return the enabled variables that belong to ``category``.

    Args:
        category (str): Category to filter on, ``"common"`` by default

    Returns:
        List[VariablesResponse]: Whatever the DAO returns for that filter
    """
    criteria = {}
    criteria["enabled"] = True
    criteria["category"] = category
    return self.dao.get_list(criteria)