mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(core): Support dag scope variables
This commit is contained in:
@@ -303,6 +303,7 @@ class Config(metaclass=Singleton):
|
||||
)
|
||||
# global dbgpt api key
|
||||
self.API_KEYS = os.getenv("API_KEYS", None)
|
||||
self.ENCRYPT_KEY = os.getenv("ENCRYPT_KEY", "your_secret_key")
|
||||
|
||||
# Non-streaming scene retries
|
||||
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""Import all models to make sure they are registered with SQLAlchemy.
|
||||
"""
|
||||
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
|
||||
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
|
||||
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
|
||||
@@ -11,6 +12,7 @@ from dbgpt.serve.agent.app.recommend_question.recommend_question import (
|
||||
from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginEntity
|
||||
from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubEntity
|
||||
from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
|
||||
from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity
|
||||
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
|
||||
from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity
|
||||
from dbgpt.storage.chat_history.chat_history_db import (
|
||||
@@ -32,4 +34,5 @@ _MODELS = [
|
||||
ModelInstanceEntity,
|
||||
FlowServeEntity,
|
||||
RecommendQuestionEntity,
|
||||
FlowVariableEntity,
|
||||
]
|
||||
|
@@ -7,6 +7,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
|
||||
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
|
||||
if cfg.API_KEYS:
|
||||
system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS)
|
||||
if cfg.ENCRYPT_KEY:
|
||||
system_app.config.set("dbgpt.app.global.encrypt_key", cfg.ENCRYPT_KEY)
|
||||
|
||||
# ################################ Prompt Serve Register Begin ######################################
|
||||
from dbgpt.serve.prompt.serve import (
|
||||
|
@@ -89,6 +89,7 @@ class ComponentType(str, Enum):
|
||||
CONNECTOR_MANAGER = "dbgpt_connector_manager"
|
||||
AGENT_MANAGER = "dbgpt_agent_manager"
|
||||
RESOURCE_MANAGER = "dbgpt_resource_manager"
|
||||
VARIABLES_PROVIDER = "dbgpt_variables_provider"
|
||||
|
||||
|
||||
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
|
||||
|
@@ -11,7 +11,18 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from concurrent.futures import Executor
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
@@ -23,6 +34,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...interface.variables import VariablesProvider
|
||||
|
||||
|
||||
def _is_async_context():
|
||||
try:
|
||||
@@ -128,6 +142,8 @@ class DAGVar:
|
||||
# The executor for current DAG, this is used run some sync tasks in async DAG
|
||||
_executor: Optional[Executor] = None
|
||||
|
||||
_variables_provider: Optional["VariablesProvider"] = None
|
||||
|
||||
@classmethod
|
||||
def enter_dag(cls, dag) -> None:
|
||||
"""Enter a DAG context.
|
||||
@@ -221,6 +237,24 @@ class DAGVar:
|
||||
"""
|
||||
cls._executor = executor
|
||||
|
||||
@classmethod
|
||||
def get_variables_provider(cls) -> Optional["VariablesProvider"]:
|
||||
"""Get the current variables provider.
|
||||
|
||||
Returns:
|
||||
Optional[VariablesProvider]: The current variables provider
|
||||
"""
|
||||
return cls._variables_provider
|
||||
|
||||
@classmethod
|
||||
def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None:
|
||||
"""Set the current variables provider.
|
||||
|
||||
Args:
|
||||
variables_provider (VariablesProvider): The variables provider to set
|
||||
"""
|
||||
cls._variables_provider = variables_provider
|
||||
|
||||
|
||||
class DAGLifecycle:
|
||||
"""The lifecycle of DAG."""
|
||||
|
@@ -7,6 +7,7 @@ from ..util.parameter_util import ( # noqa: F401
|
||||
BaseDynamicOptions,
|
||||
FunctionDynamicOptions,
|
||||
OptionValue,
|
||||
VariablesDynamicOptions,
|
||||
)
|
||||
from .base import ( # noqa: F401
|
||||
IOField,
|
||||
@@ -35,4 +36,5 @@ __ALL__ = [
|
||||
"IOField",
|
||||
"BaseDynamicOptions",
|
||||
"FunctionDynamicOptions",
|
||||
"VariablesDynamicOptions",
|
||||
]
|
||||
|
@@ -6,7 +6,7 @@ import inspect
|
||||
from abc import ABC
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
@@ -15,12 +15,14 @@ from dbgpt._private.pydantic import (
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel.util.parameter_util import (
|
||||
BaseDynamicOptions,
|
||||
OptionValue,
|
||||
RefreshOptionRequest,
|
||||
)
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
from dbgpt.util.executor_utils import DefaultExecutorFactory, blocking_func_to_async
|
||||
|
||||
from .exceptions import FlowMetadataException, FlowParameterMetadataException
|
||||
from .ui import UIComponent
|
||||
@@ -490,11 +492,19 @@ class Parameter(TypeMetadata, Serializable):
|
||||
dict_value["ui"] = self.ui.to_dict()
|
||||
return dict_value
|
||||
|
||||
def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict:
|
||||
async def refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> Dict:
|
||||
"""Refresh the options of the parameter.
|
||||
|
||||
Args:
|
||||
request (RefreshOptionRequest): The request to refresh the options.
|
||||
trigger (Literal["default", "http"], optional): The trigger type.
|
||||
Defaults to "default".
|
||||
system_app (Optional[SystemApp], optional): The system app.
|
||||
|
||||
Returns:
|
||||
Dict: The response.
|
||||
@@ -503,7 +513,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
if not self.options:
|
||||
dict_value["options"] = None
|
||||
elif isinstance(self.options, BaseDynamicOptions):
|
||||
values = self.options.refresh(request)
|
||||
values = self.options.refresh(request, trigger, system_app)
|
||||
dict_value["options"] = [value.to_dict() for value in values]
|
||||
else:
|
||||
dict_value["options"] = [value.to_dict() for value in self.options]
|
||||
@@ -793,18 +803,56 @@ class BaseMetadata(BaseResource):
|
||||
]
|
||||
return dict_value
|
||||
|
||||
def refresh(self, request: List[RefreshOptionRequest]) -> Dict:
|
||||
"""Refresh the metadata."""
|
||||
async def refresh(
|
||||
self,
|
||||
request: List[RefreshOptionRequest],
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> Dict:
|
||||
"""Refresh the metadata.
|
||||
|
||||
Args:
|
||||
request (List[RefreshOptionRequest]): The refresh request
|
||||
trigger (Literal["default", "http"]): The trigger type, how to trigger
|
||||
the refresh
|
||||
system_app (Optional[SystemApp]): The system app
|
||||
"""
|
||||
executor = DefaultExecutorFactory.get_instance(system_app).create()
|
||||
|
||||
name_to_request = {req.name: req for req in request}
|
||||
parameter_requests = {
|
||||
parameter.name: name_to_request.get(parameter.name)
|
||||
for parameter in self.parameters
|
||||
}
|
||||
dict_value = self.to_dict()
|
||||
dict_value["parameters"] = [
|
||||
parameter.refresh(parameter_requests.get(parameter.name))
|
||||
for parameter in self.parameters
|
||||
]
|
||||
dict_value = model_to_dict(self, exclude={"parameters"})
|
||||
parameters = []
|
||||
for parameter in self.parameters:
|
||||
parameter_dict = parameter.to_dict()
|
||||
parameter_request = parameter_requests.get(parameter.name)
|
||||
if not parameter.options:
|
||||
options = None
|
||||
elif isinstance(parameter.options, BaseDynamicOptions):
|
||||
options_obj = parameter.options
|
||||
if options_obj.support_async(system_app, parameter_request):
|
||||
values = await options_obj.async_refresh(
|
||||
parameter_request, trigger, system_app
|
||||
)
|
||||
else:
|
||||
values = await blocking_func_to_async(
|
||||
executor,
|
||||
options_obj.refresh,
|
||||
parameter_request,
|
||||
trigger,
|
||||
system_app,
|
||||
)
|
||||
options = [value.to_dict() for value in values]
|
||||
else:
|
||||
options = [value.to_dict() for value in self.options]
|
||||
parameter_dict["options"] = options
|
||||
parameters.append(parameter_dict)
|
||||
|
||||
dict_value["parameters"] = parameters
|
||||
|
||||
return dict_value
|
||||
|
||||
|
||||
@@ -1090,14 +1138,23 @@ class FlowRegistry:
|
||||
"""Get the metadata list."""
|
||||
return [item.metadata.to_dict() for item in self._registry.values()]
|
||||
|
||||
def refresh(
|
||||
self, key: str, is_operator: bool, request: List[RefreshOptionRequest]
|
||||
async def refresh(
|
||||
self,
|
||||
key: str,
|
||||
is_operator: bool,
|
||||
request: List[RefreshOptionRequest],
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> Dict:
|
||||
"""Refresh the metadata."""
|
||||
if is_operator:
|
||||
return _get_operator_class(key).metadata.refresh(request) # type: ignore
|
||||
return await _get_operator_class(key).metadata.refresh( # type: ignore
|
||||
request, trigger, system_app
|
||||
)
|
||||
else:
|
||||
return _get_resource_class(key).metadata.refresh(request)
|
||||
return await _get_resource_class(key).metadata.refresh(
|
||||
request, trigger, system_app
|
||||
)
|
||||
|
||||
|
||||
_OPERATOR_REGISTRY: FlowRegistry = FlowRegistry()
|
||||
|
@@ -71,7 +71,7 @@ class PanelEditorMixin(BaseModel):
|
||||
class UIComponent(RefreshableMixin, Serializable, BaseModel):
|
||||
"""UI component."""
|
||||
|
||||
class UIAttribute(StatusMixin, BaseModel):
|
||||
class UIAttribute(BaseModel):
|
||||
"""Base UI attribute."""
|
||||
|
||||
disabled: bool = Field(
|
||||
@@ -106,7 +106,7 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
|
||||
class UISelect(UIComponent):
|
||||
"""Select component."""
|
||||
|
||||
class UIAttribute(UIComponent.UIAttribute):
|
||||
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
|
||||
"""Select attribute."""
|
||||
|
||||
show_search: bool = Field(
|
||||
@@ -138,7 +138,7 @@ class UISelect(UIComponent):
|
||||
class UICascader(UIComponent):
|
||||
"""Cascader component."""
|
||||
|
||||
class UIAttribute(UIComponent.UIAttribute):
|
||||
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
|
||||
"""Cascader attribute."""
|
||||
|
||||
show_search: bool = Field(
|
||||
@@ -178,7 +178,7 @@ class UICheckbox(UIComponent):
|
||||
class UIDatePicker(UIComponent):
|
||||
"""Date picker component."""
|
||||
|
||||
class UIAttribute(UIComponent.UIAttribute):
|
||||
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
|
||||
"""Date picker attribute."""
|
||||
|
||||
placement: Optional[
|
||||
@@ -199,7 +199,7 @@ class UIDatePicker(UIComponent):
|
||||
class UIInput(UIComponent):
|
||||
"""Input component."""
|
||||
|
||||
class UIAttribute(UIComponent.UIAttribute):
|
||||
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
|
||||
"""Input attribute."""
|
||||
|
||||
prefix: Optional[str] = Field(
|
||||
@@ -216,7 +216,7 @@ class UIInput(UIComponent):
|
||||
None,
|
||||
description="Whether to show count",
|
||||
)
|
||||
maxlength: Optional[int] = Field(
|
||||
max_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The maximum length of the input",
|
||||
)
|
||||
@@ -294,7 +294,7 @@ class UISlider(UIComponent):
|
||||
class UITimePicker(UIComponent):
|
||||
"""Time picker component."""
|
||||
|
||||
class UIAttribute(UIComponent.UIAttribute):
|
||||
class UIAttribute(StatusMixin, UIComponent.UIAttribute):
|
||||
"""Time picker attribute."""
|
||||
|
||||
format: Optional[str] = Field(
|
||||
@@ -377,15 +377,20 @@ class UIUpload(UIComponent):
|
||||
)
|
||||
|
||||
|
||||
class UIVariableInput(UIInput):
|
||||
"""Variable input component."""
|
||||
class UIVariablesInput(UIInput):
|
||||
"""Variables input component."""
|
||||
|
||||
ui_type: Literal["variable"] = Field("variable", frozen=True) # type: ignore
|
||||
ui_type: Literal["variable"] = Field("variables", frozen=True) # type: ignore
|
||||
key: str = Field(..., description="The key of the variable")
|
||||
key_type: Literal["common", "secret"] = Field(
|
||||
"common",
|
||||
description="The type of the key",
|
||||
)
|
||||
scope: str = Field("global", description="The scope of the variables")
|
||||
scope_key: Optional[str] = Field(
|
||||
None,
|
||||
description="The key of the scope",
|
||||
)
|
||||
refresh: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether to enable the refresh",
|
||||
@@ -396,7 +401,7 @@ class UIVariableInput(UIInput):
|
||||
self._check_options(parameter_dict.get("options", {}))
|
||||
|
||||
|
||||
class UIPasswordInput(UIVariableInput):
|
||||
class UIPasswordInput(UIVariablesInput):
|
||||
"""Password input component."""
|
||||
|
||||
ui_type: Literal["password"] = Field("password", frozen=True) # type: ignore
|
||||
|
@@ -2,10 +2,12 @@
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from contextvars import ContextVar
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
@@ -29,6 +31,11 @@ from dbgpt.util.tracer import root_tracer
|
||||
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
|
||||
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...interface.variables import VariablesProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
F = TypeVar("F", bound=FunctionType)
|
||||
|
||||
CALL_DATA = Union[Dict[str, Any], Any]
|
||||
@@ -92,6 +99,9 @@ class BaseOperatorMeta(ABCMeta):
|
||||
kwargs.get("system_app") or DAGVar.get_current_system_app()
|
||||
)
|
||||
executor = kwargs.get("executor") or DAGVar.get_executor()
|
||||
variables_provider = (
|
||||
kwargs.get("variables_provider") or DAGVar.get_variables_provider()
|
||||
)
|
||||
if not executor:
|
||||
if system_app:
|
||||
executor = system_app.get_component(
|
||||
@@ -102,14 +112,24 @@ class BaseOperatorMeta(ABCMeta):
|
||||
else:
|
||||
executor = DefaultExecutorFactory().create()
|
||||
DAGVar.set_executor(executor)
|
||||
if not variables_provider:
|
||||
from ...interface.variables import VariablesProvider
|
||||
|
||||
if system_app:
|
||||
variables_provider = system_app.get_component(
|
||||
ComponentType.VARIABLES_PROVIDER,
|
||||
VariablesProvider,
|
||||
default_component=None,
|
||||
)
|
||||
else:
|
||||
from ...interface.variables import StorageVariablesProvider
|
||||
|
||||
variables_provider = StorageVariablesProvider()
|
||||
DAGVar.set_variables_provider(variables_provider)
|
||||
|
||||
if not task_id and dag:
|
||||
task_id = dag._new_node_id()
|
||||
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
|
||||
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
|
||||
# for arg in sig_cache.parameters:
|
||||
# if arg not in kwargs:
|
||||
# kwargs[arg] = default_args[arg]
|
||||
if not kwargs.get("dag"):
|
||||
kwargs["dag"] = dag
|
||||
if not kwargs.get("task_id"):
|
||||
@@ -120,6 +140,8 @@ class BaseOperatorMeta(ABCMeta):
|
||||
kwargs["system_app"] = system_app
|
||||
if not kwargs.get("executor"):
|
||||
kwargs["executor"] = executor
|
||||
if not kwargs.get("variables_provider"):
|
||||
kwargs["variables_provider"] = variables_provider
|
||||
real_obj = func(self, *args, **kwargs)
|
||||
return real_obj
|
||||
|
||||
@@ -150,6 +172,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
dag: Optional[DAG] = None,
|
||||
runner: Optional[WorkflowRunner] = None,
|
||||
can_skip_in_branch: bool = True,
|
||||
variables_provider: Optional["VariablesProvider"] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Create a BaseOperator with an optional workflow runner.
|
||||
@@ -171,6 +194,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
self._runner: WorkflowRunner = runner
|
||||
self._dag_ctx: Optional[DAGContext] = None
|
||||
self._can_skip_in_branch = can_skip_in_branch
|
||||
self._variables_provider = variables_provider
|
||||
|
||||
@property
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
@@ -199,6 +223,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
if not task_log_id:
|
||||
raise ValueError(f"The task log ID can't be empty, current node {self}")
|
||||
CURRENT_DAG_CONTEXT.set(dag_ctx)
|
||||
# Resolve variables
|
||||
await self._resolve_variables(dag_ctx)
|
||||
return await self._do_run(dag_ctx)
|
||||
|
||||
@abstractmethod
|
||||
@@ -347,6 +373,21 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
"""Check if the operator can be skipped in the branch."""
|
||||
return self._can_skip_in_branch
|
||||
|
||||
async def _resolve_variables(self, _: DAGContext):
|
||||
from ...interface.variables import VariablesPlaceHolder
|
||||
|
||||
if not self._variables_provider:
|
||||
return
|
||||
for attr, value in self.__dict__.items():
|
||||
if isinstance(value, VariablesPlaceHolder):
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
value.parse, self._variables_provider
|
||||
)
|
||||
logger.debug(
|
||||
f"Resolve variable {attr} with value {resolved_value} for {self}"
|
||||
)
|
||||
setattr(self, attr, resolved_value)
|
||||
|
||||
|
||||
def initialize_runner(runner: WorkflowRunner):
|
||||
"""Initialize the default runner."""
|
||||
|
111
dbgpt/core/awel/tests/test_dag_variables.py
Normal file
111
dbgpt/core/awel/tests/test_dag_variables.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...interface.variables import (
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
VariablesPlaceHolder,
|
||||
)
|
||||
from .. import DAG, DAGVar, InputOperator, MapOperator, SimpleInputSource
|
||||
|
||||
|
||||
class VariablesOperator(MapOperator[str, str]):
|
||||
def __init__(self, int_var: int, str_var: str, secret: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._int_var = int_var
|
||||
self._str_var = str_var
|
||||
self._secret = secret
|
||||
|
||||
async def map(self, x: str) -> str:
|
||||
return (
|
||||
f"x: {x}, int_var: {self._int_var}, str_var: {self._str_var}, "
|
||||
f"secret: {self._secret}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_dag():
|
||||
with DAG("test_dag") as dag:
|
||||
input_node = InputOperator(input_source=SimpleInputSource.from_callable())
|
||||
map_node = MapOperator(lambda x: x * 2)
|
||||
input_node >> map_node
|
||||
return dag
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_variables(**kwargs):
|
||||
variables_provider = StorageVariablesProvider()
|
||||
DAGVar.set_variables_provider(variables_provider)
|
||||
|
||||
vars = kwargs.get("vars")
|
||||
variables = {}
|
||||
if vars and isinstance(vars, dict):
|
||||
for param_key, param_var in vars.items():
|
||||
key = param_var.get("key")
|
||||
value = param_var.get("value")
|
||||
value_type = param_var.get("value_type")
|
||||
category = param_var.get("category", "common")
|
||||
id = VariablesIdentifier.from_str_identifier(key)
|
||||
variables_provider.save(
|
||||
StorageVariables.from_identifier(
|
||||
id, value, value_type, label="", category=category
|
||||
)
|
||||
)
|
||||
variables[param_key] = VariablesPlaceHolder(param_key, key, value_type)
|
||||
else:
|
||||
raise ValueError("vars is required.")
|
||||
|
||||
with DAG("simple_dag") as dag:
|
||||
map_node = VariablesOperator(**variables)
|
||||
yield map_node
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def variables_node(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_variables(**param) as node:
|
||||
yield node
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_dag(default_dag: DAG):
|
||||
leaf_node = default_dag.leaf_nodes[0]
|
||||
res = await leaf_node.call(2)
|
||||
assert res == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"variables_node",
|
||||
[
|
||||
(
|
||||
{
|
||||
"vars": {
|
||||
"int_var": {
|
||||
"key": "int_key@my_int_var@global",
|
||||
"value": 0,
|
||||
"value_type": "int",
|
||||
},
|
||||
"str_var": {
|
||||
"key": "str_key@my_str_var@global",
|
||||
"value": "1",
|
||||
"value_type": "str",
|
||||
},
|
||||
"secret": {
|
||||
"key": "secret_key@my_secret_var@global",
|
||||
"value": "2131sdsdf",
|
||||
"value_type": "str",
|
||||
"category": "secret",
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
],
|
||||
indirect=["variables_node"],
|
||||
)
|
||||
async def test_input_nodes(variables_node: VariablesOperator):
|
||||
res = await variables_node.call("test")
|
||||
assert res == "x: test, int_var: 0, str_var: 1, secret: 2131sdsdf"
|
@@ -2,9 +2,10 @@
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_validator
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
_DEFAULT_DYNAMIC_REGISTRY = {}
|
||||
@@ -29,6 +30,21 @@ class RefreshOptionRequest(BaseModel):
|
||||
depends: Optional[List[RefreshOptionDependency]] = Field(
|
||||
None, description="The depends of the refresh config"
|
||||
)
|
||||
variables_key: Optional[str] = Field(
|
||||
None, description="The variables key to refresh"
|
||||
)
|
||||
variables_scope: Optional[str] = Field(
|
||||
None, description="The variables scope to refresh"
|
||||
)
|
||||
variables_scope_key: Optional[str] = Field(
|
||||
None, description="The variables scope key to refresh"
|
||||
)
|
||||
variables_sys_code: Optional[str] = Field(
|
||||
None, description="The system code to refresh"
|
||||
)
|
||||
variables_user_name: Optional[str] = Field(
|
||||
None, description="The user name to refresh"
|
||||
)
|
||||
|
||||
|
||||
class OptionValue(Serializable, BaseModel):
|
||||
@@ -49,13 +65,57 @@ class OptionValue(Serializable, BaseModel):
|
||||
class BaseDynamicOptions(Serializable, BaseModel, ABC):
|
||||
"""The base dynamic options."""
|
||||
|
||||
def support_async(
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
) -> bool:
|
||||
"""Whether the dynamic options support async.
|
||||
|
||||
Args:
|
||||
system_app (Optional[SystemApp]): The system app
|
||||
request (Optional[RefreshOptionRequest]): The refresh request
|
||||
|
||||
Returns:
|
||||
bool: Whether the dynamic options support async
|
||||
"""
|
||||
return False
|
||||
|
||||
def option_values(self) -> List[OptionValue]:
|
||||
"""Return the option values of the parameter."""
|
||||
return self.refresh(None)
|
||||
|
||||
@abstractmethod
|
||||
def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options."""
|
||||
def refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options.
|
||||
|
||||
Args:
|
||||
request (Optional[RefreshOptionRequest]): The refresh request
|
||||
trigger (Literal["default", "http"]): The trigger type, how to trigger
|
||||
the refresh
|
||||
system_app (Optional[SystemApp]): The system app
|
||||
"""
|
||||
|
||||
async def async_refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options async.
|
||||
|
||||
Args:
|
||||
request (Optional[RefreshOptionRequest]): The refresh request
|
||||
trigger (Literal["default", "http"]): The trigger type, how to trigger
|
||||
the refresh
|
||||
system_app (Optional[SystemApp]): The system app
|
||||
"""
|
||||
raise NotImplementedError("The dynamic options does not support async.")
|
||||
|
||||
|
||||
class FunctionDynamicOptions(BaseDynamicOptions):
|
||||
@@ -68,7 +128,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
|
||||
..., description="The unique id of the function to generate the dynamic options"
|
||||
)
|
||||
|
||||
def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]:
|
||||
def refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options."""
|
||||
if not request or not request.depends:
|
||||
return self.func()
|
||||
@@ -96,6 +161,109 @@ class FunctionDynamicOptions(BaseDynamicOptions):
|
||||
return {"func_id": self.func_id}
|
||||
|
||||
|
||||
class VariablesDynamicOptions(BaseDynamicOptions):
|
||||
"""The variables dynamic options."""
|
||||
|
||||
def support_async(
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
) -> bool:
|
||||
"""Whether the dynamic options support async."""
|
||||
if not system_app or not request or not request.variables_key:
|
||||
return False
|
||||
|
||||
from ...interface.variables import BuiltinVariablesProvider
|
||||
|
||||
provider: BuiltinVariablesProvider = system_app.get_component(
|
||||
request.variables_key,
|
||||
component_type=BuiltinVariablesProvider,
|
||||
default_component=None,
|
||||
)
|
||||
if not provider:
|
||||
return False
|
||||
return provider.support_async()
|
||||
|
||||
def refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options."""
|
||||
if (
|
||||
trigger == "default"
|
||||
or not request
|
||||
or not request.variables_key
|
||||
or not request.variables_scope
|
||||
):
|
||||
# Only refresh when trigger is http and request is not None
|
||||
return []
|
||||
if not system_app:
|
||||
raise ValueError("The system app is required when refresh the variables.")
|
||||
from ...interface.variables import VariablesProvider
|
||||
|
||||
vp: VariablesProvider = VariablesProvider.get_instance(system_app)
|
||||
variables = vp.get_variables(
|
||||
key=request.variables_key,
|
||||
scope=request.variables_scope,
|
||||
scope_key=request.variables_scope_key,
|
||||
sys_code=request.variables_sys_code,
|
||||
user_name=request.variables_user_name,
|
||||
)
|
||||
options = []
|
||||
for var in variables:
|
||||
options.append(
|
||||
OptionValue(
|
||||
label=var.label,
|
||||
name=var.name,
|
||||
value=var.value,
|
||||
)
|
||||
)
|
||||
return options
|
||||
|
||||
async def async_refresh(
|
||||
self,
|
||||
request: Optional[RefreshOptionRequest] = None,
|
||||
trigger: Literal["default", "http"] = "default",
|
||||
system_app: Optional[SystemApp] = None,
|
||||
) -> List[OptionValue]:
|
||||
"""Refresh the dynamic options async."""
|
||||
if (
|
||||
trigger == "default"
|
||||
or not request
|
||||
or not request.variables_key
|
||||
or not request.variables_scope
|
||||
):
|
||||
return []
|
||||
if not system_app:
|
||||
raise ValueError("The system app is required when refresh the variables.")
|
||||
from ...interface.variables import VariablesProvider
|
||||
|
||||
vp: VariablesProvider = VariablesProvider.get_instance(system_app)
|
||||
variables = await vp.async_get_variables(
|
||||
key=request.variables_key,
|
||||
scope=request.variables_scope,
|
||||
scope_key=request.variables_scope_key,
|
||||
sys_code=request.variables_sys_code,
|
||||
user_name=request.variables_user_name,
|
||||
)
|
||||
options = []
|
||||
for var in variables:
|
||||
options.append(
|
||||
OptionValue(
|
||||
label=var.label,
|
||||
name=var.name,
|
||||
value=var.value,
|
||||
)
|
||||
)
|
||||
return options
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
return {"key": self.key}
|
||||
|
||||
|
||||
def _generate_unique_id(func: Callable) -> str:
|
||||
if func.__name__ == "<lambda>":
|
||||
func_id = f"lambda_{inspect.getfile(func)}_{inspect.getsourcelines(func)}"
|
||||
|
@@ -3,13 +3,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
|
||||
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from ..awel.flow import Parameter, ResourceCategory, register_resource
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class ResourceIdentifier(Serializable, ABC):
|
||||
|
114
dbgpt/core/interface/tests/test_variables.py
Normal file
114
dbgpt/core/interface/tests/test_variables.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import base64
|
||||
import os
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from ..variables import (
|
||||
FernetEncryption,
|
||||
InMemoryStorage,
|
||||
SimpleEncryption,
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
)
|
||||
|
||||
|
||||
def test_fernet_encryption():
|
||||
key = Fernet.generate_key()
|
||||
encryption = FernetEncryption(key)
|
||||
new_encryption = FernetEncryption(key)
|
||||
data = "test_data"
|
||||
salt = "test_salt"
|
||||
|
||||
encrypted_data = encryption.encrypt(data, salt)
|
||||
assert encrypted_data != data
|
||||
|
||||
decrypted_data = encryption.decrypt(encrypted_data, salt)
|
||||
assert decrypted_data == data
|
||||
assert decrypted_data == new_encryption.decrypt(encrypted_data, salt)
|
||||
|
||||
|
||||
def test_simple_encryption():
|
||||
key = base64.b64encode(os.urandom(32)).decode()
|
||||
encryption = SimpleEncryption(key)
|
||||
data = "test_data"
|
||||
salt = "test_salt"
|
||||
|
||||
encrypted_data = encryption.encrypt(data, salt)
|
||||
assert encrypted_data != data
|
||||
|
||||
decrypted_data = encryption.decrypt(encrypted_data, salt)
|
||||
assert decrypted_data == data
|
||||
|
||||
|
||||
def test_storage_variables_provider():
|
||||
storage = InMemoryStorage()
|
||||
encryption = SimpleEncryption()
|
||||
provider = StorageVariablesProvider(storage, encryption)
|
||||
|
||||
full_key = "key@name@global"
|
||||
value = "secret_value"
|
||||
value_type = "str"
|
||||
label = "test_label"
|
||||
|
||||
id = VariablesIdentifier.from_str_identifier(full_key)
|
||||
provider.save(
|
||||
StorageVariables.from_identifier(
|
||||
id, value, value_type, label, category="secret"
|
||||
)
|
||||
)
|
||||
|
||||
loaded_variable_value = provider.get(full_key)
|
||||
assert loaded_variable_value == value
|
||||
|
||||
|
||||
def test_variables_identifier():
|
||||
full_key = "key@name@global@scope_key@sys_code@user_name"
|
||||
identifier = VariablesIdentifier.from_str_identifier(full_key)
|
||||
|
||||
assert identifier.key == "key"
|
||||
assert identifier.name == "name"
|
||||
assert identifier.scope == "global"
|
||||
assert identifier.scope_key == "scope_key"
|
||||
assert identifier.sys_code == "sys_code"
|
||||
assert identifier.user_name == "user_name"
|
||||
|
||||
str_identifier = identifier.str_identifier
|
||||
assert str_identifier == full_key
|
||||
|
||||
|
||||
def test_storage_variables():
|
||||
key = "test_key"
|
||||
name = "test_name"
|
||||
label = "test_label"
|
||||
value = "test_value"
|
||||
value_type = "str"
|
||||
category = "common"
|
||||
scope = "global"
|
||||
|
||||
storage_variable = StorageVariables(
|
||||
key=key,
|
||||
name=name,
|
||||
label=label,
|
||||
value=value,
|
||||
value_type=value_type,
|
||||
category=category,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
assert storage_variable.key == key
|
||||
assert storage_variable.name == name
|
||||
assert storage_variable.label == label
|
||||
assert storage_variable.value == value
|
||||
assert storage_variable.value_type == value_type
|
||||
assert storage_variable.category == category
|
||||
assert storage_variable.scope == scope
|
||||
|
||||
dict_representation = storage_variable.to_dict()
|
||||
assert dict_representation["key"] == key
|
||||
assert dict_representation["name"] == name
|
||||
assert dict_representation["label"] == label
|
||||
assert dict_representation["value"] == value
|
||||
assert dict_representation["value_type"] == value_type
|
||||
assert dict_representation["category"] == category
|
||||
assert dict_representation["scope"] == scope
|
678
dbgpt/core/interface/variables.py
Normal file
678
dbgpt/core/interface/variables.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""Variables Module."""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.util.executor_utils import (
|
||||
DefaultExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
blocking_func_to_async_no_executor,
|
||||
)
|
||||
|
||||
from .storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE"
|
||||
|
||||
BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows"
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes"
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
|
||||
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
|
||||
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
|
||||
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
|
||||
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
|
||||
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
|
||||
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces"
|
||||
|
||||
|
||||
class Encryption(ABC):
|
||||
"""Encryption interface."""
|
||||
|
||||
name: str = "__abstract__"
|
||||
|
||||
@abstractmethod
|
||||
def encrypt(self, data: str, salt: str) -> str:
|
||||
"""Encrypt the data."""
|
||||
|
||||
@abstractmethod
|
||||
def decrypt(self, encrypted_data: str, salt: str) -> str:
|
||||
"""Decrypt the data."""
|
||||
|
||||
|
||||
def _generate_key_from_password(
|
||||
password: bytes, salt: Optional[Union[str, bytes]] = None
|
||||
):
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
if salt is None:
|
||||
salt = os.urandom(16)
|
||||
elif isinstance(salt, str):
|
||||
salt = salt.encode()
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(kdf.derive(password))
|
||||
return key, salt
|
||||
|
||||
|
||||
class FernetEncryption(Encryption):
|
||||
"""Fernet encryption.
|
||||
|
||||
A symmetric encryption algorithm that uses the same key for both encryption and
|
||||
decryption which is powered by the cryptography library.
|
||||
"""
|
||||
|
||||
name = "fernet"
|
||||
|
||||
def __init__(self, key: Optional[bytes] = None):
|
||||
"""Initialize the fernet encryption."""
|
||||
if key is not None and isinstance(key, str):
|
||||
key = key.encode()
|
||||
try:
|
||||
from cryptography.fernet import Fernet
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"cryptography is required for encryption, please install by running "
|
||||
"`pip install cryptography`"
|
||||
)
|
||||
if key is None:
|
||||
key = Fernet.generate_key()
|
||||
self.key = key
|
||||
|
||||
def encrypt(self, data: str, salt: str) -> str:
|
||||
"""Encrypt the data with the salt.
|
||||
|
||||
Args:
|
||||
data (str): The data to encrypt.
|
||||
salt (str): The salt to use, which is used to derive the key.
|
||||
|
||||
Returns:
|
||||
str: The encrypted data.
|
||||
"""
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
key, salt = _generate_key_from_password(self.key, salt)
|
||||
fernet = Fernet(key)
|
||||
encrypted_secret = fernet.encrypt(data.encode()).decode()
|
||||
return encrypted_secret
|
||||
|
||||
def decrypt(self, encrypted_data: str, salt: str) -> str:
|
||||
"""Decrypt the data with the salt.
|
||||
|
||||
Args:
|
||||
encrypted_data (str): The encrypted data.
|
||||
salt (str): The salt to use, which is used to derive the key.
|
||||
|
||||
Returns:
|
||||
str: The decrypted data.
|
||||
"""
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
key, salt = _generate_key_from_password(self.key, salt)
|
||||
fernet = Fernet(key)
|
||||
return fernet.decrypt(encrypted_data.encode()).decode()
|
||||
|
||||
|
||||
class SimpleEncryption(Encryption):
|
||||
"""Simple implementation of encryption.
|
||||
|
||||
A simple encryption algorithm that uses a key to XOR the data.
|
||||
"""
|
||||
|
||||
name = "simple"
|
||||
|
||||
def __init__(self, key: Optional[str] = None):
|
||||
"""Initialize the simple encryption."""
|
||||
if key is None:
|
||||
key = base64.b64encode(os.urandom(32)).decode()
|
||||
self.key = key
|
||||
|
||||
def _derive_key(self, salt: str) -> bytes:
|
||||
return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000)
|
||||
|
||||
def encrypt(self, data: str, salt: str) -> str:
|
||||
"""Encrypt the data with the salt."""
|
||||
key = self._derive_key(salt)
|
||||
encrypted = bytes(
|
||||
x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1))
|
||||
)
|
||||
return base64.b64encode(encrypted).decode()
|
||||
|
||||
def decrypt(self, encrypted_data: str, salt: str) -> str:
|
||||
"""Decrypt the data with the salt."""
|
||||
key = self._derive_key(salt)
|
||||
data = base64.b64decode(encrypted_data)
|
||||
decrypted = bytes(
|
||||
x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1))
|
||||
)
|
||||
return decrypted.decode()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VariablesIdentifier(ResourceIdentifier):
|
||||
"""The variables identifier."""
|
||||
|
||||
identifier_split: str = dataclasses.field(default="@", init=False)
|
||||
|
||||
key: str
|
||||
name: str
|
||||
scope: str = "global"
|
||||
scope_key: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
if not self.key or not self.name or not self.scope:
|
||||
raise ValueError("Key, name, and scope are required.")
|
||||
|
||||
if any(
|
||||
self.identifier_split in key
|
||||
for key in [
|
||||
self.key,
|
||||
self.name,
|
||||
self.scope,
|
||||
self.scope_key,
|
||||
self.sys_code,
|
||||
self.user_name,
|
||||
]
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in "
|
||||
f"key, name, scope, scope_key, sys_code, user_name."
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Return the string identifier of the identifier."""
|
||||
return self.identifier_split.join(
|
||||
key or ""
|
||||
for key in [
|
||||
self.key,
|
||||
self.name,
|
||||
self.scope,
|
||||
self.scope_key,
|
||||
self.sys_code,
|
||||
self.user_name,
|
||||
]
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the identifier to a dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict of the identifier.
|
||||
"""
|
||||
return {
|
||||
"key": self.key,
|
||||
"name": self.name,
|
||||
"scope": self.scope,
|
||||
"scope_key": self.scope_key,
|
||||
"sys_code": self.sys_code,
|
||||
"user_name": self.user_name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_str_identifier(
|
||||
cls, str_identifier: str, identifier_split: str = "@"
|
||||
) -> "VariablesIdentifier":
|
||||
"""Create a VariablesIdentifier from a string identifier.
|
||||
|
||||
Args:
|
||||
str_identifier (str): The string identifier.
|
||||
identifier_split (str): The identifier split.
|
||||
|
||||
Returns:
|
||||
VariablesIdentifier: The VariablesIdentifier.
|
||||
"""
|
||||
keys = str_identifier.split(identifier_split)
|
||||
if not keys:
|
||||
raise ValueError("Invalid string identifier.")
|
||||
if len(keys) < 2:
|
||||
raise ValueError("Invalid string identifier, must have name")
|
||||
if len(keys) < 3:
|
||||
raise ValueError("Invalid string identifier, must have scope")
|
||||
|
||||
return cls(
|
||||
key=keys[0],
|
||||
name=keys[1],
|
||||
scope=keys[2],
|
||||
scope_key=keys[3] if len(keys) > 3 else None,
|
||||
sys_code=keys[4] if len(keys) > 4 else None,
|
||||
user_name=keys[5] if len(keys) > 5 else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StorageVariables(StorageItem):
|
||||
"""The storage variables."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
label: str
|
||||
value: Any
|
||||
category: Literal["common", "secret"] = "common"
|
||||
scope: str = "global"
|
||||
value_type: Optional[str] = None
|
||||
scope_key: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
encryption_method: Optional[str] = None
|
||||
salt: Optional[str] = None
|
||||
enabled: int = 1
|
||||
|
||||
_identifier: VariablesIdentifier = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
self._identifier = VariablesIdentifier(
|
||||
key=self.key,
|
||||
name=self.name,
|
||||
scope=self.scope,
|
||||
scope_key=self.scope_key,
|
||||
sys_code=self.sys_code,
|
||||
user_name=self.user_name,
|
||||
)
|
||||
if not self.value_type:
|
||||
self.value_type = type(self.value).__name__
|
||||
|
||||
@property
|
||||
def identifier(self) -> ResourceIdentifier:
|
||||
"""Return the identifier."""
|
||||
return self._identifier
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge with another storage variables."""
|
||||
if not isinstance(other, StorageVariables):
|
||||
raise ValueError(f"Cannot merge with {type(other)}")
|
||||
self.from_object(other)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the storage variables to a dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict of the storage variables.
|
||||
"""
|
||||
return {
|
||||
**self._identifier.to_dict(),
|
||||
"label": self.label,
|
||||
"value": self.value,
|
||||
"value_type": self.value_type,
|
||||
"category": self.category,
|
||||
"encryption_method": self.encryption_method,
|
||||
"salt": self.salt,
|
||||
}
|
||||
|
||||
def from_object(self, other: "StorageVariables") -> None:
|
||||
"""Copy the values from another storage variables object."""
|
||||
self.label = other.label
|
||||
self.value = other.value
|
||||
self.value_type = other.value_type
|
||||
self.category = other.category
|
||||
self.scope = other.scope
|
||||
self.scope_key = other.scope_key
|
||||
self.sys_code = other.sys_code
|
||||
self.user_name = other.user_name
|
||||
self.encryption_method = other.encryption_method
|
||||
self.salt = other.salt
|
||||
|
||||
@classmethod
|
||||
def from_identifier(
|
||||
cls,
|
||||
identifier: VariablesIdentifier,
|
||||
value: Any,
|
||||
value_type: str,
|
||||
label: str = "",
|
||||
category: Literal["common", "secret"] = "common",
|
||||
encryption_method: Optional[str] = None,
|
||||
salt: Optional[str] = None,
|
||||
) -> "StorageVariables":
|
||||
"""Copy the values from an identifier."""
|
||||
return cls(
|
||||
key=identifier.key,
|
||||
name=identifier.name,
|
||||
label=label,
|
||||
value=value,
|
||||
value_type=value_type,
|
||||
category=category,
|
||||
scope=identifier.scope,
|
||||
scope_key=identifier.scope_key,
|
||||
sys_code=identifier.sys_code,
|
||||
user_name=identifier.user_name,
|
||||
encryption_method=encryption_method,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
|
||||
class VariablesProvider(BaseComponent, ABC):
|
||||
"""The variables provider interface."""
|
||||
|
||||
name = ComponentType.VARIABLES_PROVIDER.value
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
|
||||
) -> Any:
|
||||
"""Query variables from storage."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
|
||||
@abstractmethod
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get variables by key."""
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get variables by key async."""
|
||||
raise NotImplementedError("Current variables provider does not support async.")
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether the variables provider support async."""
|
||||
return False
|
||||
|
||||
|
||||
class VariablesPlaceHolder:
|
||||
"""The variables place holder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
param_name: str,
|
||||
full_key: str,
|
||||
value_type: str,
|
||||
default_value: Any = _EMPTY_DEFAULT_VALUE,
|
||||
):
|
||||
"""Initialize the variables place holder."""
|
||||
self.param_name = param_name
|
||||
self.full_key = full_key
|
||||
self.value_type = value_type
|
||||
self.default_value = default_value
|
||||
|
||||
def parse(self, variables_provider: VariablesProvider) -> Any:
|
||||
"""Parse the variables."""
|
||||
value = variables_provider.get(self.full_key, self.default_value)
|
||||
if value:
|
||||
return self._cast_to_type(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def _cast_to_type(self, value: Any) -> Any:
|
||||
if self.value_type == "str":
|
||||
return str(value)
|
||||
elif self.value_type == "int":
|
||||
return int(value)
|
||||
elif self.value_type == "float":
|
||||
return float(value)
|
||||
elif self.value_type == "bool":
|
||||
if value.lower() in ["true", "1"]:
|
||||
return True
|
||||
elif value.lower() in ["false", "0"]:
|
||||
return False
|
||||
else:
|
||||
return bool(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
"""Return the representation of the variables place holder."""
|
||||
return (
|
||||
f"<VariablesPlaceHolder "
|
||||
f"{self.param_name} {self.full_key} {self.value_type}>"
|
||||
)
|
||||
|
||||
|
||||
class StorageVariablesProvider(VariablesProvider):
|
||||
"""The storage variables provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface] = None,
|
||||
encryption: Optional[Encryption] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
key: Optional[str] = None,
|
||||
):
|
||||
"""Initialize the storage variables provider."""
|
||||
if storage is None:
|
||||
storage = InMemoryStorage()
|
||||
self.system_app = system_app
|
||||
self.encryption = encryption or SimpleEncryption(key)
|
||||
|
||||
self.storage = storage
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the storage variables provider."""
|
||||
self.system_app = system_app
|
||||
|
||||
def get(
|
||||
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
|
||||
) -> Any:
|
||||
"""Query variables from storage."""
|
||||
key = VariablesIdentifier.from_str_identifier(full_key)
|
||||
variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
|
||||
if variable is None:
|
||||
if default_value == _EMPTY_DEFAULT_VALUE:
|
||||
raise ValueError(f"Variable {full_key} not found")
|
||||
return default_value
|
||||
variable.value = self.deserialize_value(variable.value)
|
||||
if (
|
||||
variable.value is not None
|
||||
and variable.category == "secret"
|
||||
and variable.encryption_method
|
||||
and variable.salt
|
||||
):
|
||||
variable.value = self.encryption.decrypt(variable.value, variable.salt)
|
||||
return variable.value
|
||||
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
if variables_item.category == "secret":
|
||||
salt = base64.b64encode(os.urandom(16)).decode()
|
||||
variables_item.value = self.encryption.encrypt(
|
||||
str(variables_item.value), salt
|
||||
)
|
||||
variables_item.encryption_method = self.encryption.name
|
||||
variables_item.salt = salt
|
||||
# Replace value to a json serializable object
|
||||
variables_item.value = self.serialize_value(variables_item.value)
|
||||
|
||||
self.storage.save_or_update(variables_item)
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Query variables from storage."""
|
||||
# Try to get builtin variables
|
||||
is_builtin, builtin_variables = self._get_builtins_variables(
|
||||
key,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
if is_builtin:
|
||||
return builtin_variables
|
||||
variables = self.storage.query(
|
||||
QuerySpec(
|
||||
conditions={
|
||||
"key": key,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
"enabled": 1,
|
||||
}
|
||||
),
|
||||
StorageVariables,
|
||||
)
|
||||
for variable in variables:
|
||||
variable.value = self.deserialize_value(variable.value)
|
||||
return variables
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Query variables from storage async."""
|
||||
# Try to get builtin variables
|
||||
is_builtin, builtin_variables = await self._async_get_builtins_variables(
|
||||
key,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
if is_builtin:
|
||||
return builtin_variables
|
||||
executor_factory: Optional[
|
||||
DefaultExecutorFactory
|
||||
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
|
||||
if executor_factory:
|
||||
return await blocking_func_to_async(
|
||||
executor_factory.create(),
|
||||
self.get_variables,
|
||||
key,
|
||||
scope,
|
||||
scope_key,
|
||||
sys_code,
|
||||
user_name,
|
||||
)
|
||||
else:
|
||||
return await blocking_func_to_async_no_executor(
|
||||
self.get_variables, key, scope, scope_key, sys_code, user_name
|
||||
)
|
||||
|
||||
def _get_builtins_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> Tuple[bool, List[StorageVariables]]:
|
||||
"""Get builtin variables."""
|
||||
if self.system_app is None:
|
||||
return False, []
|
||||
provider: BuiltinVariablesProvider = self.system_app.get_component(
|
||||
key,
|
||||
component_type=BuiltinVariablesProvider,
|
||||
default_component=None,
|
||||
)
|
||||
if not provider:
|
||||
return False, []
|
||||
return True, provider.get_variables(
|
||||
key,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
async def _async_get_builtins_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> Tuple[bool, List[StorageVariables]]:
|
||||
"""Get builtin variables."""
|
||||
if self.system_app is None:
|
||||
return False, []
|
||||
provider: BuiltinVariablesProvider = self.system_app.get_component(
|
||||
key,
|
||||
component_type=BuiltinVariablesProvider,
|
||||
default_component=None,
|
||||
)
|
||||
if not provider:
|
||||
return False, []
|
||||
if not provider.support_async():
|
||||
return False, []
|
||||
return True, await provider.async_get_variables(
|
||||
key,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def serialize_value(cls, value: Any) -> str:
|
||||
"""Serialize the value."""
|
||||
value_dict = {"value": value}
|
||||
return json.dumps(value_dict, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def deserialize_value(cls, value: str) -> Any:
|
||||
"""Deserialize the value."""
|
||||
value_dict = json.loads(value)
|
||||
return value_dict["value"]
|
||||
|
||||
|
||||
class BuiltinVariablesProvider(VariablesProvider, ABC):
|
||||
"""The builtin variables provider.
|
||||
|
||||
You can implement this class to provide builtin variables. Such LLMs, agents,
|
||||
datasource, knowledge base, etc.
|
||||
"""
|
||||
|
||||
name = "dbgpt_variables_builtin"
|
||||
|
||||
def __init__(self, system_app: Optional[SystemApp] = None):
|
||||
"""Initialize the builtin variables provider."""
|
||||
self.system_app = system_app
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the builtin variables provider."""
|
||||
self.system_app = system_app
|
||||
|
||||
def get(
|
||||
self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE
|
||||
) -> Any:
|
||||
"""Query variables from storage."""
|
||||
raise NotImplementedError("BuiltinVariablesProvider does not support get.")
|
||||
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
raise NotImplementedError("BuiltinVariablesProvider does not support save.")
|
@@ -1,7 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from dbgpt.serve.core.config import BaseServeConfig
|
||||
from dbgpt.serve.core.schemas import Result, add_exception_handler
|
||||
from dbgpt.serve.core.serve import BaseServe
|
||||
from dbgpt.serve.core.service import BaseService
|
||||
from dbgpt.util.executor_utils import BlockingFunction, DefaultExecutorFactory
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async as _blocking_func_to_async
|
||||
|
||||
__ALL__ = [
|
||||
"Result",
|
||||
@@ -10,3 +14,11 @@ __ALL__ = [
|
||||
"BaseService",
|
||||
"BaseServe",
|
||||
]
|
||||
|
||||
|
||||
async def blocking_func_to_async(
|
||||
system_app, func: BlockingFunction, *args, **kwargs
|
||||
) -> Any:
|
||||
"""Run a potentially blocking function within an executor."""
|
||||
executor = DefaultExecutorFactory.get_instance(system_app).create()
|
||||
return await _blocking_func_to_async(executor, func, *args, **kwargs)
|
||||
|
@@ -7,12 +7,19 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse
|
||||
from ..service.variables_service import VariablesService
|
||||
from .schemas import (
|
||||
RefreshNodeRequest,
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -23,7 +30,12 @@ global_system_app: Optional[SystemApp] = None
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
return Service.get_instance(global_system_app)
|
||||
|
||||
|
||||
def get_variable_service() -> VariablesService:
|
||||
"""Get the service instance"""
|
||||
return VariablesService.get_instance(global_system_app)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
@@ -261,16 +273,80 @@ async def refresh_nodes(refresh_request: RefreshNodeRequest):
|
||||
"""
|
||||
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
|
||||
|
||||
new_metadata = _OPERATOR_REGISTRY.refresh(
|
||||
key=refresh_request.id,
|
||||
is_operator=refresh_request.flow_type == "operator",
|
||||
request=refresh_request.refresh,
|
||||
# Make sure the variables provider is initialized
|
||||
_ = get_variable_service().variables_provider
|
||||
|
||||
new_metadata = await _OPERATOR_REGISTRY.refresh(
|
||||
refresh_request.id,
|
||||
refresh_request.flow_type == "operator",
|
||||
refresh_request.refresh,
|
||||
"http",
|
||||
global_system_app,
|
||||
)
|
||||
return Result.succ(new_metadata)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/variables",
|
||||
response_model=Result[VariablesResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def create_variables(
|
||||
variables_request: VariablesRequest,
|
||||
) -> Result[VariablesResponse]:
|
||||
"""Create a new Variables entity
|
||||
|
||||
Args:
|
||||
variables_request (VariablesRequest): The request
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, get_variable_service().create, variables_request
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/variables/{v_id}",
|
||||
response_model=Result[VariablesResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def update_variables(
|
||||
v_id: int, variables_request: VariablesRequest
|
||||
) -> Result[VariablesResponse]:
|
||||
"""Update a Variables entity
|
||||
|
||||
Args:
|
||||
v_id (int): The variable id
|
||||
variables_request (VariablesRequest): The request
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, get_variable_service().update, v_id, variables_request
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
from .variables_provider import (
|
||||
BuiltinAllSecretVariablesProvider,
|
||||
BuiltinAllVariablesProvider,
|
||||
BuiltinEmbeddingsVariablesProvider,
|
||||
BuiltinFlowVariablesProvider,
|
||||
BuiltinLLMVariablesProvider,
|
||||
BuiltinNodeVariablesProvider,
|
||||
)
|
||||
|
||||
global global_system_app
|
||||
system_app.register(Service)
|
||||
system_app.register(VariablesService)
|
||||
system_app.register(BuiltinFlowVariablesProvider)
|
||||
system_app.register(BuiltinNodeVariablesProvider)
|
||||
system_app.register(BuiltinAllVariablesProvider)
|
||||
system_app.register(BuiltinAllSecretVariablesProvider)
|
||||
system_app.register(BuiltinLLMVariablesProvider)
|
||||
system_app.register(BuiltinEmbeddingsVariablesProvider)
|
||||
global_system_app = system_app
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import List, Literal
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
||||
@@ -17,6 +17,69 @@ class ServerResponse(FlowPanel):
|
||||
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
|
||||
|
||||
|
||||
class VariablesRequest(BaseModel):
|
||||
"""Variable request model.
|
||||
|
||||
For creating a new variable in the DB-GPT.
|
||||
"""
|
||||
|
||||
key: str = Field(
|
||||
...,
|
||||
description="The key of the variable to create",
|
||||
examples=["dbgpt.model.openai.api_key"],
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="The name of the variable to create",
|
||||
examples=["my_first_openai_key"],
|
||||
)
|
||||
label: str = Field(
|
||||
...,
|
||||
description="The label of the variable to create",
|
||||
examples=["My First OpenAI Key"],
|
||||
)
|
||||
value: Any = Field(
|
||||
..., description="The value of the variable to create", examples=["1234567890"]
|
||||
)
|
||||
value_type: Literal["str", "int", "float", "bool"] = Field(
|
||||
"str",
|
||||
description="The type of the value of the variable to create",
|
||||
examples=["str", "int", "float", "bool"],
|
||||
)
|
||||
category: Literal["common", "secret"] = Field(
|
||||
...,
|
||||
description="The category of the variable to create",
|
||||
examples=["common"],
|
||||
)
|
||||
scope: str = Field(
|
||||
...,
|
||||
description="The scope of the variable to create",
|
||||
examples=["global"],
|
||||
)
|
||||
scope_key: Optional[str] = Field(
|
||||
...,
|
||||
description="The scope key of the variable to create",
|
||||
examples=["dbgpt"],
|
||||
)
|
||||
enabled: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether the variable is enabled",
|
||||
examples=[True],
|
||||
)
|
||||
user_name: Optional[str] = Field(None, description="User name")
|
||||
sys_code: Optional[str] = Field(None, description="System code")
|
||||
|
||||
|
||||
class VariablesResponse(VariablesRequest):
|
||||
"""Variable response model."""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="The id of the variable",
|
||||
examples=[1],
|
||||
)
|
||||
|
||||
|
||||
class RefreshNodeRequest(BaseModel):
|
||||
"""Flow response model"""
|
||||
|
||||
|
260
dbgpt/serve/flow/api/variables_provider.py
Normal file
260
dbgpt/serve/flow/api/variables_provider.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from dbgpt.core.interface.variables import (
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES,
|
||||
BUILTIN_VARIABLES_CORE_FLOWS,
|
||||
BUILTIN_VARIABLES_CORE_LLMS,
|
||||
BUILTIN_VARIABLES_CORE_SECRETS,
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES,
|
||||
BuiltinVariablesProvider,
|
||||
StorageVariables,
|
||||
)
|
||||
|
||||
from ..service.service import Service
|
||||
from .endpoints import get_service, get_variable_service
|
||||
from .schemas import ServerResponse
|
||||
|
||||
|
||||
class BuiltinFlowVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin flow variables provider.
|
||||
|
||||
Provide all flows by variables "${dbgpt.core.flow.flows}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_FLOWS
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
service: Service = get_service()
|
||||
page_result = service.get_list_by_page(
|
||||
{
|
||||
"user_name": user_name,
|
||||
"sys_code": sys_code,
|
||||
},
|
||||
1,
|
||||
1000,
|
||||
)
|
||||
flows: List[ServerResponse] = page_result.items
|
||||
variables = []
|
||||
for flow in flows:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=flow.name,
|
||||
label=flow.label,
|
||||
value=flow.uid,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinNodeVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin node variables provider.
|
||||
|
||||
Provide all nodes by variables "${dbgpt.core.flow.nodes}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_FLOW_NODES
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
|
||||
|
||||
metadata_list = _OPERATOR_REGISTRY.metadata_list()
|
||||
variables = []
|
||||
for metadata in metadata_list:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=metadata["name"],
|
||||
label=metadata["label"],
|
||||
value=metadata["id"],
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinAllVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin all variables provider.
|
||||
|
||||
Provide all variables by variables "${dbgpt.core.variables}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_VARIABLES
|
||||
|
||||
def _get_variables_from_db(
|
||||
self,
|
||||
key: str,
|
||||
scope: str,
|
||||
scope_key: Optional[str],
|
||||
sys_code: Optional[str],
|
||||
user_name: Optional[str],
|
||||
category: Literal["common", "secret"] = "common",
|
||||
) -> List[StorageVariables]:
|
||||
storage_variables = get_variable_service().list_all_variables(category)
|
||||
variables = []
|
||||
for var in storage_variables:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=var.name,
|
||||
label=var.label,
|
||||
value=var.value,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables.
|
||||
|
||||
TODO: Return all builtin variables
|
||||
"""
|
||||
return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name)
|
||||
|
||||
|
||||
class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider):
|
||||
"""Builtin all secret variables provider.
|
||||
|
||||
Provide all secret variables by variables "${dbgpt.core.secrets}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_SECRETS
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return self._get_variables_from_db(
|
||||
key, scope, scope_key, sys_code, user_name, "secret"
|
||||
)
|
||||
|
||||
|
||||
class BuiltinLLMVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin LLM variables provider.
|
||||
|
||||
Provide all LLM variables by variables "${dbgpt.core.llmv}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_LLMS
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether the dynamic options support async."""
|
||||
return True
|
||||
|
||||
async def _get_models(
|
||||
self,
|
||||
key: str,
|
||||
scope: str,
|
||||
scope_key: Optional[str],
|
||||
sys_code: Optional[str],
|
||||
user_name: Optional[str],
|
||||
expect_worker_type: str = "llm",
|
||||
) -> List[StorageVariables]:
|
||||
from dbgpt.model.cluster.controller.controller import BaseModelController
|
||||
|
||||
controller = BaseModelController.get_instance(self.system_app)
|
||||
models = await controller.get_all_instances(healthy_only=True)
|
||||
model_dict = {}
|
||||
for model in models:
|
||||
worker_name, worker_type = model.model_name.split("@")
|
||||
if expect_worker_type == worker_type:
|
||||
model_dict[worker_name] = model
|
||||
variables = []
|
||||
for worker_name, model in model_dict.items():
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=worker_name,
|
||||
label=worker_name,
|
||||
value=worker_name,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return await self._get_models(key, scope, scope_key, sys_code, user_name)
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
raise NotImplementedError(
|
||||
"Not implemented get variables sync, please use async_get_variables"
|
||||
)
|
||||
|
||||
|
||||
class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider):
|
||||
"""Builtin embeddings variables provider.
|
||||
|
||||
Provide all embeddings variables by variables "${dbgpt.core.embeddings}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_EMBEDDINGS
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return await self._get_models(
|
||||
key, scope, scope_key, sys_code, user_name, "text2vec"
|
||||
)
|
@@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow"
|
||||
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow."
|
||||
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service"
|
||||
# Database table name
|
||||
SERVER_APP_TABLE_NAME = "dbgpt_serve_flow"
|
||||
SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig):
|
||||
load_dbgpts_interval: int = field(
|
||||
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
|
||||
)
|
||||
encrypt_key: Optional[str] = field(
|
||||
default=None, metadata={"help": "The key to encrypt the data"}
|
||||
)
|
||||
|
@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt._private.pydantic import model_to_dict
|
||||
from dbgpt.core.awel.flow.flow_factory import State
|
||||
from dbgpt.core.interface.variables import StorageVariablesProvider
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
from ..api.schemas import (
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
@@ -74,6 +80,56 @@ class ServeEntity(Model):
|
||||
return editable is None or editable == 0
|
||||
|
||||
|
||||
class VariablesEntity(Model):
|
||||
__tablename__ = SERVER_APP_VARIABLES_TABLE_NAME
|
||||
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
key = Column(String(128), index=True, nullable=False, comment="Variable key")
|
||||
name = Column(String(128), index=True, nullable=True, comment="Variable name")
|
||||
label = Column(String(128), nullable=True, comment="Variable label")
|
||||
value = Column(Text, nullable=True, comment="Variable value, JSON format")
|
||||
value_type = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable value type(string, int, float, bool)",
|
||||
)
|
||||
category = Column(
|
||||
String(32),
|
||||
default="common",
|
||||
nullable=True,
|
||||
comment="Variable category(common or secret)",
|
||||
)
|
||||
encryption_method = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable encryption method(fernet, simple, rsa, aes)",
|
||||
)
|
||||
salt = Column(String(128), nullable=True, comment="Variable salt")
|
||||
scope = Column(
|
||||
String(32),
|
||||
default="global",
|
||||
nullable=True,
|
||||
comment="Variable scope(global,flow,app,agent,datasource,flow:uid,"
|
||||
"flow:dag_name,agent:agent_name) etc",
|
||||
)
|
||||
scope_key = Column(
|
||||
String(256),
|
||||
nullable=True,
|
||||
comment="Variable scope key, default is empty, for scope is 'flow:uid', "
|
||||
"the scope_key is uid of flow",
|
||||
)
|
||||
enabled = Column(
|
||||
Integer,
|
||||
default=1,
|
||||
nullable=True,
|
||||
comment="Variable enabled, 0: disabled, 1: enabled",
|
||||
)
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The DAO class for Flow"""
|
||||
|
||||
@@ -222,3 +278,108 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
session.merge(entry)
|
||||
session.commit()
|
||||
return self.get_one(query_request)
|
||||
|
||||
|
||||
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
|
||||
"""The DAO class for Variables"""
|
||||
|
||||
def __init__(self, serve_config: ServeConfig):
|
||||
super().__init__()
|
||||
self._serve_config = serve_config
|
||||
|
||||
def from_request(
|
||||
self, request: Union[VariablesRequest, Dict[str, Any]]
|
||||
) -> VariablesEntity:
|
||||
"""Convert the request to an entity
|
||||
|
||||
Args:
|
||||
request (Union[VariablesRequest, Dict[str, Any]]): The request
|
||||
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
model_to_dict(request) if isinstance(request, VariablesRequest) else request
|
||||
)
|
||||
value = StorageVariablesProvider.serialize_value(request_dict.get("value"))
|
||||
enabled = 1 if request_dict.get("enabled", True) else 0
|
||||
new_dict = {
|
||||
"key": request_dict.get("key"),
|
||||
"name": request_dict.get("name"),
|
||||
"label": request_dict.get("label"),
|
||||
"value": value,
|
||||
"value_type": request_dict.get("value_type"),
|
||||
"category": request_dict.get("category"),
|
||||
"encryption_method": request_dict.get("encryption_method"),
|
||||
"salt": request_dict.get("salt"),
|
||||
"scope": request_dict.get("scope"),
|
||||
"scope_key": request_dict.get("scope_key"),
|
||||
"enabled": enabled,
|
||||
"user_name": request_dict.get("user_name"),
|
||||
"sys_code": request_dict.get("sys_code"),
|
||||
}
|
||||
entity = VariablesEntity(**new_dict)
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: VariablesEntity) -> VariablesRequest:
|
||||
"""Convert the entity to a request
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
REQ: The request
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesRequest(
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
)
|
||||
|
||||
def to_response(self, entity: VariablesEntity) -> VariablesResponse:
|
||||
"""Convert the entity to a response
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
RES: The response
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesResponse(
|
||||
id=entity.id,
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
)
|
||||
|
69
dbgpt/serve/flow/models/variables_adapter.py
Normal file
69
dbgpt/serve/flow/models/variables_adapter.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.storage import StorageItemAdapter
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
|
||||
|
||||
from .models import VariablesEntity
|
||||
|
||||
|
||||
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
|
||||
"""Variables adapter.
|
||||
|
||||
Convert between storage format and database model.
|
||||
"""
|
||||
|
||||
def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
|
||||
"""Convert to storage format."""
|
||||
return VariablesEntity(
|
||||
key=item.key,
|
||||
name=item.name,
|
||||
label=item.label,
|
||||
value=item.value,
|
||||
value_type=item.value_type,
|
||||
category=item.category,
|
||||
encryption_method=item.encryption_method,
|
||||
salt=item.salt,
|
||||
scope=item.scope,
|
||||
scope_key=item.scope_key,
|
||||
sys_code=item.sys_code,
|
||||
user_name=item.user_name,
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
|
||||
"""Convert from storage format."""
|
||||
return StorageVariables(
|
||||
key=model.key,
|
||||
name=model.name,
|
||||
label=model.label,
|
||||
value=model.value,
|
||||
value_type=model.value_type,
|
||||
category=model.category,
|
||||
encryption_method=model.encryption_method,
|
||||
salt=model.salt,
|
||||
scope=model.scope,
|
||||
scope_key=model.scope_key,
|
||||
sys_code=model.sys_code,
|
||||
user_name=model.user_name,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[VariablesEntity],
|
||||
resource_id: VariablesIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get query for identifier."""
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(VariablesEntity)
|
||||
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)
|
||||
|
||||
# enabled must be True
|
||||
query_obj = query_obj.filter(VariablesEntity.enabled == 1)
|
||||
return query_obj
|
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.interface.variables import VariablesProvider
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
@@ -40,6 +41,8 @@ class Serve(BaseServe):
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
self._variables_provider: Optional[VariablesProvider] = None
|
||||
self._serve_config: Optional[ServeConfig] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
@@ -62,5 +65,37 @@ class Serve(BaseServe):
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application."""
|
||||
# TODO: Your code here
|
||||
from dbgpt.core.interface.variables import (
|
||||
FernetEncryption,
|
||||
StorageVariablesProvider,
|
||||
)
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .models.models import ServeEntity, VariablesEntity
|
||||
from .models.variables_adapter import VariablesAdapter
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
storage_adapter = VariablesAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
self._db_manager,
|
||||
VariablesEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
self._variables_provider = StorageVariablesProvider(
|
||||
storage=storage,
|
||||
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
||||
system_app=self._system_app,
|
||||
)
|
||||
|
||||
@property
|
||||
def variables_provider(self):
|
||||
"""Get the variables provider of the serve app with db storage"""
|
||||
return self._variables_provider
|
||||
|
148
dbgpt/serve/flow/service/variables_service.py
Normal file
148
dbgpt/serve/flow/service/variables_service.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
|
||||
from dbgpt.serve.core import BaseService
|
||||
|
||||
from ..api.schemas import VariablesRequest, VariablesResponse
|
||||
from ..config import (
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
|
||||
ServeConfig,
|
||||
)
|
||||
from ..models.models import VariablesDao, VariablesEntity
|
||||
|
||||
|
||||
class VariablesService(
|
||||
BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
|
||||
):
|
||||
"""Variables service"""
|
||||
|
||||
name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: VariablesDao = dao
|
||||
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = self._dao or VariablesDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
def dao(self) -> VariablesDao:
|
||||
"""Returns the internal DAO."""
|
||||
return self._dao
|
||||
|
||||
@property
|
||||
def variables_provider(self) -> VariablesProvider:
|
||||
"""Returns the internal VariablesProvider.
|
||||
|
||||
Returns:
|
||||
VariablesProvider: The internal VariablesProvider
|
||||
"""
|
||||
variables_provider = VariablesProvider.get_instance(
|
||||
self._system_app, default_component=None
|
||||
)
|
||||
if variables_provider:
|
||||
return variables_provider
|
||||
else:
|
||||
from ..serve import Serve
|
||||
|
||||
variables_provider = Serve.get_instance(self._system_app).variables_provider
|
||||
self._system_app.register_instance(variables_provider)
|
||||
return variables_provider
|
||||
|
||||
@property
|
||||
def config(self) -> ServeConfig:
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def create(self, request: VariablesRequest) -> VariablesResponse:
|
||||
"""Create a new entity
|
||||
|
||||
Args:
|
||||
request (VariablesRequest): The request
|
||||
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
variables = StorageVariables(
|
||||
key=request.key,
|
||||
name=request.name,
|
||||
label=request.label,
|
||||
value=request.value,
|
||||
value_type=request.value_type,
|
||||
category=request.category,
|
||||
scope=request.scope,
|
||||
scope_key=request.scope_key,
|
||||
user_name=request.user_name,
|
||||
sys_code=request.sys_code,
|
||||
)
|
||||
self.variables_provider.save(variables)
|
||||
query = {
|
||||
"key": request.key,
|
||||
"name": request.name,
|
||||
"scope": request.scope,
|
||||
"scope_key": request.scope_key,
|
||||
"sys_code": request.sys_code,
|
||||
"user_name": request.user_name,
|
||||
"enabled": request.enabled,
|
||||
}
|
||||
return self.dao.get_one(query)
|
||||
|
||||
def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
|
||||
"""Update variables.
|
||||
|
||||
Args:
|
||||
request (VariablesRequest): The request
|
||||
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
variables = StorageVariables(
|
||||
key=request.key,
|
||||
name=request.name,
|
||||
label=request.label,
|
||||
value=request.value,
|
||||
value_type=request.value_type,
|
||||
category=request.category,
|
||||
scope=request.scope,
|
||||
scope_key=request.scope_key,
|
||||
user_name=request.user_name,
|
||||
sys_code=request.sys_code,
|
||||
)
|
||||
exist_value = self.variables_provider.get(
|
||||
variables.identifier.str_identifier, None
|
||||
)
|
||||
if exist_value is None:
|
||||
raise ValueError(
|
||||
f"Variable {variables.identifier.str_identifier} not found"
|
||||
)
|
||||
self.variables_provider.save(variables)
|
||||
query = {
|
||||
"key": request.key,
|
||||
"name": request.name,
|
||||
"scope": request.scope,
|
||||
"scope_key": request.scope_key,
|
||||
"sys_code": request.sys_code,
|
||||
"user_name": request.user_name,
|
||||
"enabled": request.enabled,
|
||||
}
|
||||
return self.dao.get_one(query)
|
||||
|
||||
def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
|
||||
"""List all variables."""
|
||||
return self.dao.get_list({"enabled": True, "category": category})
|
Reference in New Issue
Block a user