mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
feat(core): Support complex variables parsing
This commit is contained in:
@@ -380,27 +380,40 @@ class Parameter(TypeMetadata, Serializable):
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def _covert_to_real_type(cls, type_cls: str, v: Any):
|
||||
def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
|
||||
if type_cls and v is not None:
|
||||
typed_value: Any = v
|
||||
try:
|
||||
# Try to convert the value to the type.
|
||||
if type_cls == "builtins.str":
|
||||
return str(v)
|
||||
typed_value = str(v)
|
||||
elif type_cls == "builtins.int":
|
||||
return int(v)
|
||||
typed_value = int(v)
|
||||
elif type_cls == "builtins.float":
|
||||
return float(v)
|
||||
typed_value = float(v)
|
||||
elif type_cls == "builtins.bool":
|
||||
if str(v).lower() in ["false", "0", "", "no", "off"]:
|
||||
return False
|
||||
return bool(v)
|
||||
typed_value = bool(v)
|
||||
return typed_value
|
||||
except ValueError:
|
||||
raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
|
||||
return v
|
||||
|
||||
def get_typed_value(self) -> Any:
|
||||
"""Get the typed value."""
|
||||
return self._covert_to_real_type(self.type_cls, self.value)
|
||||
"""Get the typed value.
|
||||
|
||||
Returns:
|
||||
Any: The typed value. VariablesPlaceHolder if the value is a variable
|
||||
string. Otherwise, the real type value.
|
||||
"""
|
||||
from ...interface.variables import VariablesPlaceHolder, is_variable_string
|
||||
|
||||
is_variables = is_variable_string(self.value) if self.value else False
|
||||
if is_variables and self.value is not None and isinstance(self.value, str):
|
||||
return VariablesPlaceHolder(self.name, self.value)
|
||||
else:
|
||||
return self._covert_to_real_type(self.type_cls, self.value)
|
||||
|
||||
def get_typed_default(self) -> Any:
|
||||
"""Get the typed default."""
|
||||
|
223
dbgpt/core/awel/flow/tests/test_flow_variables.py
Normal file
223
dbgpt/core/awel/flow/tests/test_flow_variables.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
IOField,
|
||||
OperatorCategory,
|
||||
Parameter,
|
||||
VariablesDynamicOptions,
|
||||
ViewMetadata,
|
||||
ui,
|
||||
)
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel
|
||||
|
||||
from ...tests.conftest import variables_provider
|
||||
|
||||
|
||||
class MyVariablesOperator(MapOperator[str, str]):
|
||||
metadata = ViewMetadata(
|
||||
label="My Test Variables Operator",
|
||||
name="my_test_variables_operator",
|
||||
category=OperatorCategory.EXAMPLE,
|
||||
description="An example flow operator that includes a variables option.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"OpenAI API Key",
|
||||
"openai_api_key",
|
||||
type=str,
|
||||
placeholder="Please select the OpenAI API key",
|
||||
description="The OpenAI API key to use.",
|
||||
options=VariablesDynamicOptions(),
|
||||
ui=ui.UIPasswordInput(
|
||||
key="dbgpt.model.openai.api_key",
|
||||
),
|
||||
),
|
||||
Parameter.build_from(
|
||||
"Model",
|
||||
"model",
|
||||
type=str,
|
||||
placeholder="Please select the model",
|
||||
description="The model to use.",
|
||||
options=VariablesDynamicOptions(),
|
||||
ui=ui.UIVariablesInput(
|
||||
key="dbgpt.model.openai.model",
|
||||
),
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"User Name",
|
||||
"user_name",
|
||||
str,
|
||||
description="The name of the user.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Model info",
|
||||
"model",
|
||||
str,
|
||||
description="The model info.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, openai_api_key: str, model: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._openai_api_key = openai_api_key
|
||||
self._model = model
|
||||
|
||||
async def map(self, user_name: str) -> str:
|
||||
dict_dict = {
|
||||
"openai_api_key": self._openai_api_key,
|
||||
"model": self._model,
|
||||
}
|
||||
json_data = json.dumps(dict_dict, ensure_ascii=False)
|
||||
return "Your name is %s, and your model info is %s." % (user_name, json_data)
|
||||
|
||||
|
||||
class EndOperator(MapOperator[str, str]):
|
||||
metadata = ViewMetadata(
|
||||
label="End Operator",
|
||||
name="end_operator",
|
||||
category=OperatorCategory.EXAMPLE,
|
||||
description="An example flow operator that ends the flow.",
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Input",
|
||||
"input",
|
||||
str,
|
||||
description="The input to the end operator.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Output",
|
||||
"output",
|
||||
str,
|
||||
description="The output of the end operator.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def map(self, input: str) -> str:
|
||||
return f"End operator received input: {input}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_flow():
|
||||
operators = [MyVariablesOperator, EndOperator]
|
||||
metadata_list = [operator.metadata.to_dict() for operator in operators]
|
||||
node_names = {}
|
||||
name_to_parameters_dict = {
|
||||
"my_test_variables_operator": {
|
||||
"openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
|
||||
"model": "${dbgpt.model.openai.model:default_model@global}",
|
||||
}
|
||||
}
|
||||
name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
|
||||
ui_nodes = []
|
||||
for metadata in metadata_list:
|
||||
type_name = metadata["type_name"]
|
||||
name = metadata["name"]
|
||||
id = metadata["id"]
|
||||
if type_name in node_names:
|
||||
raise ValueError(f"Duplicate node type name: {type_name}")
|
||||
# Replace id to flow data id.
|
||||
metadata["id"] = f"{id}_0"
|
||||
parameters = metadata["parameters"]
|
||||
parameters_dict = name_to_parameters_dict.get(name, {})
|
||||
for parameter in parameters:
|
||||
parameter_name = parameter["name"]
|
||||
if parameter_name in parameters_dict:
|
||||
parameter["value"] = parameters_dict[parameter_name]
|
||||
ui_nodes.append(
|
||||
{
|
||||
"width": 288,
|
||||
"height": 352,
|
||||
"id": metadata["id"],
|
||||
"position": {
|
||||
"x": -149.98120112708142,
|
||||
"y": 666.9468497341901,
|
||||
"zoom": 0.0,
|
||||
},
|
||||
"type": "customNode",
|
||||
"position_absolute": {
|
||||
"x": -149.98120112708142,
|
||||
"y": 666.9468497341901,
|
||||
"zoom": 0.0,
|
||||
},
|
||||
"data": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
ui_edges = []
|
||||
source_id = name_to_metadata_dict["my_test_variables_operator"]["id"]
|
||||
target_id = name_to_metadata_dict["end_operator"]["id"]
|
||||
ui_edges.append(
|
||||
{
|
||||
"source": source_id,
|
||||
"target": target_id,
|
||||
"source_order": 0,
|
||||
"target_order": 0,
|
||||
"id": f"{source_id}|{target_id}",
|
||||
"source_handle": f"{source_id}|outputs|0",
|
||||
"target_handle": f"{target_id}|inputs|0",
|
||||
"type": "buttonedge",
|
||||
}
|
||||
)
|
||||
return {
|
||||
"nodes": ui_nodes,
|
||||
"edges": ui_edges,
|
||||
"viewport": {
|
||||
"x": 509.2191773722104,
|
||||
"y": -66.11286175905718,
|
||||
"zoom": 1.252741002590748,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"variables_provider",
|
||||
[
|
||||
(
|
||||
{
|
||||
"vars": {
|
||||
"openai_api_key": {
|
||||
"key": "${dbgpt.model.openai.api_key:my_key@global}",
|
||||
"value": "my_openai_api_key",
|
||||
"value_type": "str",
|
||||
"category": "secret",
|
||||
},
|
||||
"model": {
|
||||
"key": "${dbgpt.model.openai.model:default_model@global}",
|
||||
"value": "GPT-4o",
|
||||
"value_type": "str",
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
],
|
||||
indirect=["variables_provider"],
|
||||
)
|
||||
async def test_build_flow(json_flow, variables_provider):
|
||||
DAGVar.set_variables_provider(variables_provider)
|
||||
flow_data = FlowData(**json_flow)
|
||||
flow_panel = FlowPanel(
|
||||
label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed"
|
||||
)
|
||||
factory = FlowFactory()
|
||||
dag = factory.build(flow_panel)
|
||||
|
||||
leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
|
||||
result = await leaf_node.call("Alice")
|
||||
assert (
|
||||
result
|
||||
== "End operator received input: Your name is Alice, and your model info is "
|
||||
'{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.'
|
||||
)
|
@@ -19,7 +19,7 @@ _UI_TYPE = Literal[
|
||||
"time_picker",
|
||||
"tree_select",
|
||||
"upload",
|
||||
"variable",
|
||||
"variables",
|
||||
"password",
|
||||
"code_editor",
|
||||
]
|
||||
|
@@ -380,6 +380,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
if not self._variables_provider:
|
||||
return
|
||||
# TODO: Resolve variables parallel
|
||||
for attr, value in self.__dict__.items():
|
||||
if isinstance(value, VariablesPlaceHolder):
|
||||
resolved_value = await self.blocking_func_to_async(
|
||||
|
@@ -1,17 +1,15 @@
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from .. import (
|
||||
DAGContext,
|
||||
DefaultWorkflowRunner,
|
||||
InputOperator,
|
||||
SimpleInputSource,
|
||||
TaskState,
|
||||
WorkflowRunner,
|
||||
from ...interface.variables import (
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
)
|
||||
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
||||
@@ -102,3 +100,32 @@ async def stream_input_nodes(request):
|
||||
param["is_stream"] = True
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_variables(**kwargs):
|
||||
vp = StorageVariablesProvider()
|
||||
vars = kwargs.get("vars")
|
||||
if vars and isinstance(vars, dict):
|
||||
for param_key, param_var in vars.items():
|
||||
key = param_var.get("key")
|
||||
value = param_var.get("value")
|
||||
value_type = param_var.get("value_type")
|
||||
category = param_var.get("category", "common")
|
||||
id = VariablesIdentifier.from_str_identifier(key)
|
||||
vp.save(
|
||||
StorageVariables.from_identifier(
|
||||
id, value, value_type, label="", category=category
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("vars is required.")
|
||||
|
||||
yield vp
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def variables_provider(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_variables(**param) as vp:
|
||||
yield vp
|
||||
|
@@ -54,7 +54,7 @@ async def _create_variables(**kwargs):
|
||||
id, value, value_type, label="", category=category
|
||||
)
|
||||
)
|
||||
variables[param_key] = VariablesPlaceHolder(param_key, key, value_type)
|
||||
variables[param_key] = VariablesPlaceHolder(param_key, key)
|
||||
else:
|
||||
raise ValueError("vars is required.")
|
||||
|
||||
@@ -85,17 +85,17 @@ async def test_default_dag(default_dag: DAG):
|
||||
{
|
||||
"vars": {
|
||||
"int_var": {
|
||||
"key": "int_key@my_int_var@global",
|
||||
"key": "${int_key:my_int_var@global}",
|
||||
"value": 0,
|
||||
"value_type": "int",
|
||||
},
|
||||
"str_var": {
|
||||
"key": "str_key@my_str_var@global",
|
||||
"key": "${str_key:my_str_var@global}",
|
||||
"value": "1",
|
||||
"value_type": "str",
|
||||
},
|
||||
"secret": {
|
||||
"key": "secret_key@my_secret_var@global",
|
||||
"key": "${secret_key:my_secret_var@global}",
|
||||
"value": "2131sdsdf",
|
||||
"value_type": "str",
|
||||
"category": "secret",
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import os
|
||||
from itertools import product
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
@@ -10,6 +11,8 @@ from ..variables import (
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
build_variable_string,
|
||||
parse_variable,
|
||||
)
|
||||
|
||||
|
||||
@@ -46,7 +49,7 @@ def test_storage_variables_provider():
|
||||
encryption = SimpleEncryption()
|
||||
provider = StorageVariablesProvider(storage, encryption)
|
||||
|
||||
full_key = "key@name@global"
|
||||
full_key = "${key:name@global}"
|
||||
value = "secret_value"
|
||||
value_type = "str"
|
||||
label = "test_label"
|
||||
@@ -63,7 +66,7 @@ def test_storage_variables_provider():
|
||||
|
||||
|
||||
def test_variables_identifier():
|
||||
full_key = "key@name@global@scope_key@sys_code@user_name"
|
||||
full_key = "${key:name@global:scope_key#sys_code%user_name}"
|
||||
identifier = VariablesIdentifier.from_str_identifier(full_key)
|
||||
|
||||
assert identifier.key == "key"
|
||||
@@ -112,3 +115,213 @@ def test_storage_variables():
|
||||
assert dict_representation["value_type"] == value_type
|
||||
assert dict_representation["category"] == category
|
||||
assert dict_representation["scope"] == scope
|
||||
|
||||
|
||||
def generate_test_cases(enable_escape=False):
|
||||
# Define possible values for each field, including special characters for escaping
|
||||
_EMPTY_ = "___EMPTY___"
|
||||
fields = {
|
||||
"name": [
|
||||
None,
|
||||
"test_name",
|
||||
"test:name" if enable_escape else _EMPTY_,
|
||||
"test::name" if enable_escape else _EMPTY_,
|
||||
"test#name" if enable_escape else _EMPTY_,
|
||||
"test##name" if enable_escape else _EMPTY_,
|
||||
"test::@@@#22name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"scope": [
|
||||
None,
|
||||
"test_scope",
|
||||
"test@scope" if enable_escape else _EMPTY_,
|
||||
"test@@scope" if enable_escape else _EMPTY_,
|
||||
"test:scope" if enable_escape else _EMPTY_,
|
||||
"test:#:scope" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"scope_key": [
|
||||
None,
|
||||
"test_scope_key",
|
||||
"test:scope_key" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"sys_code": [
|
||||
None,
|
||||
"test_sys_code",
|
||||
"test#sys_code" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"user_name": [
|
||||
None,
|
||||
"test_user_name",
|
||||
"test%user_name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
}
|
||||
# Remove empty values
|
||||
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
|
||||
|
||||
# Generate all possible combinations
|
||||
combinations = product(*fields.values())
|
||||
|
||||
test_cases = []
|
||||
for combo in combinations:
|
||||
name, scope, scope_key, sys_code, user_name = combo
|
||||
|
||||
var_str = build_variable_string(
|
||||
{
|
||||
"key": "test_key",
|
||||
"name": name,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
},
|
||||
enable_escape=enable_escape,
|
||||
)
|
||||
|
||||
# Construct the expected output
|
||||
expected = {
|
||||
"key": "test_key",
|
||||
"name": name,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
}
|
||||
|
||||
test_cases.append((var_str, expected, enable_escape))
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
def test_parse_variables():
|
||||
# Run test cases without escape
|
||||
test_cases = generate_test_cases(enable_escape=False)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
assert result == expected_output, f"Test case {i} failed without escape"
|
||||
|
||||
# Run test cases with escape
|
||||
test_cases = generate_test_cases(enable_escape=True)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
print(f"input_str: {input_str}, expected_output: {expected_output}")
|
||||
result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
assert result == expected_output, f"Test case {i} failed with escape"
|
||||
|
||||
|
||||
def generate_build_test_cases(enable_escape=False):
|
||||
# Define possible values for each field, including special characters for escaping
|
||||
_EMPTY_ = "___EMPTY___"
|
||||
fields = {
|
||||
"name": [
|
||||
None,
|
||||
"test_name",
|
||||
"test:name" if enable_escape else _EMPTY_,
|
||||
"test::name" if enable_escape else _EMPTY_,
|
||||
"test\name" if enable_escape else _EMPTY_,
|
||||
"test\\name" if enable_escape else _EMPTY_,
|
||||
"test\:\#\@\%name" if enable_escape else _EMPTY_,
|
||||
"test\::\###\@@\%%name" if enable_escape else _EMPTY_,
|
||||
"test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_,
|
||||
"test\:#:name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_],
|
||||
"scope_key": [
|
||||
None,
|
||||
"test_scope_key",
|
||||
"test:scope_key" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"sys_code": [
|
||||
None,
|
||||
"test_sys_code",
|
||||
"test#sys_code" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"user_name": [
|
||||
None,
|
||||
"test_user_name",
|
||||
"test%user_name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
}
|
||||
# Remove empty values
|
||||
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
|
||||
|
||||
# Generate all possible combinations
|
||||
combinations = product(*fields.values())
|
||||
|
||||
test_cases = []
|
||||
|
||||
def escape_special_chars(s):
|
||||
if not enable_escape or s is None:
|
||||
return s
|
||||
return (
|
||||
s.replace(":", "\\:")
|
||||
.replace("@", "\\@")
|
||||
.replace("%", "\\%")
|
||||
.replace("#", "\\#")
|
||||
)
|
||||
|
||||
for combo in combinations:
|
||||
name, scope, scope_key, sys_code, user_name = combo
|
||||
|
||||
# Construct the input dictionary
|
||||
input_dict = {
|
||||
"key": "test_key",
|
||||
"name": name,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
}
|
||||
input_dict_with_escape = {
|
||||
k: escape_special_chars(v) for k, v in input_dict.items()
|
||||
}
|
||||
|
||||
# Construct the expected variable string
|
||||
expected_str = "${test_key"
|
||||
if name:
|
||||
expected_str += f":{input_dict_with_escape['name']}"
|
||||
if scope or scope_key:
|
||||
expected_str += "@"
|
||||
if scope:
|
||||
expected_str += input_dict_with_escape["scope"]
|
||||
if scope_key:
|
||||
expected_str += f":{input_dict_with_escape['scope_key']}"
|
||||
if sys_code:
|
||||
expected_str += f"#{input_dict_with_escape['sys_code']}"
|
||||
if user_name:
|
||||
expected_str += f"%{input_dict_with_escape['user_name']}"
|
||||
expected_str += "}"
|
||||
|
||||
test_cases.append((input_dict, expected_str, enable_escape))
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
def test_build_variable_string():
|
||||
# Run test cases without escape
|
||||
test_cases = generate_build_test_cases(enable_escape=False)
|
||||
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
|
||||
result = build_variable_string(input_dict, enable_escape=enable_escape)
|
||||
assert result == expected_str, f"Test case {i} failed without escape"
|
||||
|
||||
# Run test cases with escape
|
||||
test_cases = generate_build_test_cases(enable_escape=True)
|
||||
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
|
||||
print(f"input_dict: {input_dict}, expected_str: {expected_str}")
|
||||
result = build_variable_string(input_dict, enable_escape=enable_escape)
|
||||
assert result == expected_str, f"Test case {i} failed with escape"
|
||||
|
||||
|
||||
def test_variable_string_round_trip():
|
||||
# Run test cases without escape
|
||||
test_cases = generate_test_cases(enable_escape=False)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
|
||||
assert (
|
||||
built_result == input_str
|
||||
), f"Round trip test case {i} failed without escape"
|
||||
|
||||
# Run test cases with escape
|
||||
test_cases = generate_test_cases(enable_escape=True)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
|
||||
assert built_result == input_str, f"Round trip test case {i} failed with escape"
|
||||
|
@@ -182,36 +182,18 @@ class VariablesIdentifier(ResourceIdentifier):
|
||||
if not self.key or not self.name or not self.scope:
|
||||
raise ValueError("Key, name, and scope are required.")
|
||||
|
||||
if any(
|
||||
self.identifier_split in key
|
||||
for key in [
|
||||
self.key,
|
||||
self.name,
|
||||
self.scope,
|
||||
self.scope_key,
|
||||
self.sys_code,
|
||||
self.user_name,
|
||||
]
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in "
|
||||
f"key, name, scope, scope_key, sys_code, user_name."
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Return the string identifier of the identifier."""
|
||||
return self.identifier_split.join(
|
||||
key or ""
|
||||
for key in [
|
||||
self.key,
|
||||
self.name,
|
||||
self.scope,
|
||||
self.scope_key,
|
||||
self.sys_code,
|
||||
self.user_name,
|
||||
]
|
||||
return build_variable_string(
|
||||
{
|
||||
"key": self.key,
|
||||
"name": self.name,
|
||||
"scope": self.scope,
|
||||
"scope_key": self.scope_key,
|
||||
"sys_code": self.sys_code,
|
||||
"user_name": self.user_name,
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
@@ -230,33 +212,30 @@ class VariablesIdentifier(ResourceIdentifier):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_str_identifier(
|
||||
cls, str_identifier: str, identifier_split: str = "@"
|
||||
) -> "VariablesIdentifier":
|
||||
def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier":
|
||||
"""Create a VariablesIdentifier from a string identifier.
|
||||
|
||||
Args:
|
||||
str_identifier (str): The string identifier.
|
||||
identifier_split (str): The identifier split.
|
||||
|
||||
Returns:
|
||||
VariablesIdentifier: The VariablesIdentifier.
|
||||
"""
|
||||
keys = str_identifier.split(identifier_split)
|
||||
if not keys:
|
||||
variable_dict = parse_variable(str_identifier)
|
||||
if not variable_dict:
|
||||
raise ValueError("Invalid string identifier.")
|
||||
if len(keys) < 2:
|
||||
if not variable_dict.get("key"):
|
||||
raise ValueError("Invalid string identifier, must have key")
|
||||
if not variable_dict.get("name"):
|
||||
raise ValueError("Invalid string identifier, must have name")
|
||||
if len(keys) < 3:
|
||||
raise ValueError("Invalid string identifier, must have scope")
|
||||
|
||||
return cls(
|
||||
key=keys[0],
|
||||
name=keys[1],
|
||||
scope=keys[2],
|
||||
scope_key=keys[3] if len(keys) > 3 else None,
|
||||
sys_code=keys[4] if len(keys) > 4 else None,
|
||||
user_name=keys[5] if len(keys) > 5 else None,
|
||||
key=variable_dict["key"],
|
||||
name=variable_dict["name"],
|
||||
scope=variable_dict.get("scope", "global"),
|
||||
scope_key=variable_dict.get("scope_key"),
|
||||
sys_code=variable_dict.get("sys_code"),
|
||||
user_name=variable_dict.get("user_name"),
|
||||
)
|
||||
|
||||
|
||||
@@ -402,6 +381,26 @@ class VariablesProvider(BaseComponent, ABC):
|
||||
"""Whether the variables provider support async."""
|
||||
return False
|
||||
|
||||
def _convert_to_value_type(self, var: StorageVariables):
|
||||
"""Convert the variable to the value type."""
|
||||
if var.value is None:
|
||||
return None
|
||||
if var.value_type == "str":
|
||||
return str(var.value)
|
||||
elif var.value_type == "int":
|
||||
return int(var.value)
|
||||
elif var.value_type == "float":
|
||||
return float(var.value)
|
||||
elif var.value_type == "bool":
|
||||
if var.value.lower() in ["true", "1"]:
|
||||
return True
|
||||
elif var.value.lower() in ["false", "0"]:
|
||||
return False
|
||||
else:
|
||||
return bool(var.value)
|
||||
else:
|
||||
return var.value
|
||||
|
||||
|
||||
class VariablesPlaceHolder:
|
||||
"""The variables place holder."""
|
||||
@@ -410,46 +409,20 @@ class VariablesPlaceHolder:
|
||||
self,
|
||||
param_name: str,
|
||||
full_key: str,
|
||||
value_type: str,
|
||||
default_value: Any = _EMPTY_DEFAULT_VALUE,
|
||||
):
|
||||
"""Initialize the variables place holder."""
|
||||
self.param_name = param_name
|
||||
self.full_key = full_key
|
||||
self.value_type = value_type
|
||||
self.default_value = default_value
|
||||
|
||||
def parse(self, variables_provider: VariablesProvider) -> Any:
|
||||
"""Parse the variables."""
|
||||
value = variables_provider.get(self.full_key, self.default_value)
|
||||
if value:
|
||||
return self._cast_to_type(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def _cast_to_type(self, value: Any) -> Any:
|
||||
if self.value_type == "str":
|
||||
return str(value)
|
||||
elif self.value_type == "int":
|
||||
return int(value)
|
||||
elif self.value_type == "float":
|
||||
return float(value)
|
||||
elif self.value_type == "bool":
|
||||
if value.lower() in ["true", "1"]:
|
||||
return True
|
||||
elif value.lower() in ["false", "0"]:
|
||||
return False
|
||||
else:
|
||||
return bool(value)
|
||||
else:
|
||||
return value
|
||||
return variables_provider.get(self.full_key, self.default_value)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return the representation of the variables place holder."""
|
||||
return (
|
||||
f"<VariablesPlaceHolder "
|
||||
f"{self.param_name} {self.full_key} {self.value_type}>"
|
||||
)
|
||||
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
|
||||
|
||||
|
||||
class StorageVariablesProvider(VariablesProvider):
|
||||
@@ -493,7 +466,7 @@ class StorageVariablesProvider(VariablesProvider):
|
||||
and variable.salt
|
||||
):
|
||||
variable.value = self.encryption.decrypt(variable.value, variable.salt)
|
||||
return variable.value
|
||||
return self._convert_to_value_type(variable)
|
||||
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
@@ -676,3 +649,287 @@ class BuiltinVariablesProvider(VariablesProvider, ABC):
|
||||
def save(self, variables_item: StorageVariables) -> None:
|
||||
"""Save variables to storage."""
|
||||
raise NotImplementedError("BuiltinVariablesProvider does not support save.")
|
||||
|
||||
|
||||
def parse_variable(
|
||||
variable_str: str,
|
||||
enable_escape: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse the variable string.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
cases = [
|
||||
{
|
||||
"full_key": "${test_key:test_name@test_scope:test_scope_key}",
|
||||
"expected": {
|
||||
"key": "test_key",
|
||||
"name": "test_name",
|
||||
"scope": "test_scope",
|
||||
"scope_key": "test_scope_key",
|
||||
"sys_code": None,
|
||||
"user_name": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"full_key": "${test_key#test_sys_code}",
|
||||
"expected": {
|
||||
"key": "test_key",
|
||||
"name": None,
|
||||
"scope": None,
|
||||
"scope_key": None,
|
||||
"sys_code": "test_sys_code",
|
||||
"user_name": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"full_key": "${test_key@:test_scope_key}",
|
||||
"expected": {
|
||||
"key": "test_key",
|
||||
"name": None,
|
||||
"scope": None,
|
||||
"scope_key": "test_scope_key",
|
||||
"sys_code": None,
|
||||
"user_name": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
for case in cases:
|
||||
assert parse_variable(case["full_key"]) == case["expected"]
|
||||
Args:
|
||||
variable_str (str): The variable string.
|
||||
enable_escape (bool): Whether to handle escaped characters.
|
||||
Returns:
|
||||
Dict[str, Any]: The parsed variable.
|
||||
"""
|
||||
if not variable_str.startswith("${") or not variable_str.endswith("}"):
|
||||
raise ValueError(
|
||||
"Invalid variable format, must start with '${' and end with '}'"
|
||||
)
|
||||
|
||||
# Remove the surrounding ${ and }
|
||||
content = variable_str[2:-1]
|
||||
|
||||
# Define placeholders for escaped characters
|
||||
placeholders = {
|
||||
r"\@": "__ESCAPED_AT__",
|
||||
r"\#": "__ESCAPED_HASH__",
|
||||
r"\%": "__ESCAPED_PERCENT__",
|
||||
r"\:": "__ESCAPED_COLON__",
|
||||
}
|
||||
|
||||
if enable_escape:
|
||||
# Replace escaped characters with placeholders
|
||||
for original, placeholder in placeholders.items():
|
||||
content = content.replace(original, placeholder)
|
||||
|
||||
# Initialize the result dictionary
|
||||
result: Dict[str, Optional[str]] = {
|
||||
"key": None,
|
||||
"name": None,
|
||||
"scope": None,
|
||||
"scope_key": None,
|
||||
"sys_code": None,
|
||||
"user_name": None,
|
||||
}
|
||||
|
||||
# Split the content by special characters
|
||||
parts = content.split("@")
|
||||
|
||||
# Parse key and name
|
||||
key_name = parts[0].split("#")[0].split("%")[0]
|
||||
if ":" in key_name:
|
||||
result["key"], result["name"] = key_name.split(":", 1)
|
||||
else:
|
||||
result["key"] = key_name
|
||||
|
||||
# Parse scope and scope_key
|
||||
if len(parts) > 1:
|
||||
scope_part = parts[1].split("#")[0].split("%")[0]
|
||||
if ":" in scope_part:
|
||||
result["scope"], result["scope_key"] = scope_part.split(":", 1)
|
||||
else:
|
||||
result["scope"] = scope_part
|
||||
|
||||
# Parse sys_code
|
||||
if "#" in content:
|
||||
result["sys_code"] = content.split("#", 1)[1].split("%")[0]
|
||||
|
||||
# Parse user_name
|
||||
if "%" in content:
|
||||
result["user_name"] = content.split("%", 1)[1]
|
||||
|
||||
if enable_escape:
|
||||
# Replace placeholders back with escaped characters
|
||||
reverse_placeholders = {v: k[1:] for k, v in placeholders.items()}
|
||||
for key, value in result.items():
|
||||
if value:
|
||||
for placeholder, original in reverse_placeholders.items():
|
||||
result[key] = result[key].replace( # type: ignore
|
||||
placeholder, original
|
||||
)
|
||||
|
||||
# Replace empty strings with None
|
||||
for key, value in result.items():
|
||||
if value == "":
|
||||
result[key] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _is_variable_format(value: str) -> bool:
|
||||
if not value.startswith("${") or not value.endswith("}"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_variable_string(variable_str: str) -> bool:
|
||||
"""Check if the given string is a variable string.
|
||||
|
||||
A valid variable string should start with "${" and end with "}", and contain key
|
||||
and name
|
||||
|
||||
Args:
|
||||
variable_str (str): The string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the string is a variable string, False otherwise.
|
||||
"""
|
||||
if not _is_variable_format(variable_str):
|
||||
return False
|
||||
try:
|
||||
variable_dict = parse_variable(variable_str)
|
||||
if not variable_dict.get("key"):
|
||||
return False
|
||||
if not variable_dict.get("name"):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_list_string(variable_str: str) -> bool:
|
||||
"""Check if the given string is a variable string.
|
||||
|
||||
A valid variable list string should start with "${" and end with "}", and contain
|
||||
key and not contain name
|
||||
|
||||
A valid variable list string means that the variable is a list of variables with the
|
||||
same key.
|
||||
|
||||
Args:
|
||||
variable_str (str): The string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the string is a variable string, False otherwise.
|
||||
"""
|
||||
if not _is_variable_format(variable_str):
|
||||
return False
|
||||
try:
|
||||
variable_dict = parse_variable(variable_str)
|
||||
if not variable_dict.get("key"):
|
||||
return False
|
||||
if variable_dict.get("name"):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def build_variable_string(
|
||||
variable_dict: Dict[str, Any],
|
||||
scope_sig: str = "@",
|
||||
sys_code_sig: str = "#",
|
||||
user_sig: str = "%",
|
||||
kv_sig: str = ":",
|
||||
enable_escape: bool = True,
|
||||
) -> str:
|
||||
"""Build a variable string from the given dictionary.
|
||||
|
||||
Args:
|
||||
variable_dict (Dict[str, Any]): The dictionary containing the variable details.
|
||||
scope_sig (str): The scope signature.
|
||||
sys_code_sig (str): The sys code signature.
|
||||
user_sig (str): The user signature.
|
||||
kv_sig (str): The key-value split signature.
|
||||
enable_escape (bool): Whether to escape special characters
|
||||
|
||||
Returns:
|
||||
str: The formatted variable string.
|
||||
|
||||
Examples:
|
||||
>>> build_variable_string(
|
||||
... {
|
||||
... "key": "test_key",
|
||||
... "name": "test_name",
|
||||
... "scope": "test_scope",
|
||||
... "scope_key": "test_scope_key",
|
||||
... "sys_code": "test_sys_code",
|
||||
... "user_name": "test_user",
|
||||
... }
|
||||
... )
|
||||
'${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}'
|
||||
|
||||
>>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"})
|
||||
'${test_key@:test_scope_key}'
|
||||
|
||||
>>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"})
|
||||
'${test_key#test_sys_code}'
|
||||
|
||||
>>> build_variable_string({"key": "test_key"})
|
||||
'${test_key}'
|
||||
"""
|
||||
special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig}
|
||||
# Replace None with ""
|
||||
new_variable_dict = {key: value or "" for key, value in variable_dict.items()}
|
||||
|
||||
# Check if the variable_dict contains any special characters
|
||||
for key, value in new_variable_dict.items():
|
||||
if value != "" and any(char in value for char in special_chars):
|
||||
if enable_escape:
|
||||
# Escape special characters
|
||||
new_variable_dict[key] = (
|
||||
value.replace("@", "\\@")
|
||||
.replace("#", "\\#")
|
||||
.replace("%", "\\%")
|
||||
.replace(":", "\\:")
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{key} contains special characters, error value: {value}, special "
|
||||
f"characters: {special_chars}"
|
||||
)
|
||||
|
||||
key = new_variable_dict.get("key", "")
|
||||
name = new_variable_dict.get("name", "")
|
||||
scope = new_variable_dict.get("scope", "")
|
||||
scope_key = new_variable_dict.get("scope_key", "")
|
||||
sys_code = new_variable_dict.get("sys_code", "")
|
||||
user_name = new_variable_dict.get("user_name", "")
|
||||
|
||||
# Construct the base of the variable string
|
||||
variable_str = f"${{{key}"
|
||||
|
||||
# Add name if present
|
||||
if name:
|
||||
variable_str += f":{name}"
|
||||
|
||||
# Add scope and scope_key if present
|
||||
if scope or scope_key:
|
||||
variable_str += f"@{scope}"
|
||||
if scope_key:
|
||||
variable_str += f":{scope_key}"
|
||||
|
||||
# Add sys_code if present
|
||||
if sys_code:
|
||||
variable_str += f"#{sys_code}"
|
||||
|
||||
# Add user_name if present
|
||||
if user_name:
|
||||
variable_str += f"%{user_name}"
|
||||
|
||||
# Close the variable string
|
||||
variable_str += "}"
|
||||
|
||||
return variable_str
|
||||
|
Reference in New Issue
Block a user