feat(core): Support complex variables parsing

This commit is contained in:
Fangyin Cheng
2024-08-11 16:09:26 +08:00
parent 3b4b34ca3c
commit a348e6c0d7
8 changed files with 827 additions and 93 deletions

View File

@@ -380,27 +380,40 @@ class Parameter(TypeMetadata, Serializable):
return values
@classmethod
def _covert_to_real_type(cls, type_cls: str, v: Any):
def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
if type_cls and v is not None:
typed_value: Any = v
try:
# Try to convert the value to the type.
if type_cls == "builtins.str":
return str(v)
typed_value = str(v)
elif type_cls == "builtins.int":
return int(v)
typed_value = int(v)
elif type_cls == "builtins.float":
return float(v)
typed_value = float(v)
elif type_cls == "builtins.bool":
if str(v).lower() in ["false", "0", "", "no", "off"]:
return False
return bool(v)
typed_value = bool(v)
return typed_value
except ValueError:
raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
return v
def get_typed_value(self) -> Any:
"""Get the typed value."""
return self._covert_to_real_type(self.type_cls, self.value)
"""Get the typed value.
Returns:
Any: The typed value. VariablesPlaceHolder if the value is a variable
string. Otherwise, the real type value.
"""
from ...interface.variables import VariablesPlaceHolder, is_variable_string
is_variables = is_variable_string(self.value) if self.value else False
if is_variables and self.value is not None and isinstance(self.value, str):
return VariablesPlaceHolder(self.name, self.value)
else:
return self._covert_to_real_type(self.type_cls, self.value)
def get_typed_default(self) -> Any:
"""Get the typed default."""

View File

@@ -0,0 +1,223 @@
import json
from typing import cast
import pytest
from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
from dbgpt.core.awel.flow import (
IOField,
OperatorCategory,
Parameter,
VariablesDynamicOptions,
ViewMetadata,
ui,
)
from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel
from ...tests.conftest import variables_provider
class MyVariablesOperator(MapOperator[str, str]):
metadata = ViewMetadata(
label="My Test Variables Operator",
name="my_test_variables_operator",
category=OperatorCategory.EXAMPLE,
description="An example flow operator that includes a variables option.",
parameters=[
Parameter.build_from(
"OpenAI API Key",
"openai_api_key",
type=str,
placeholder="Please select the OpenAI API key",
description="The OpenAI API key to use.",
options=VariablesDynamicOptions(),
ui=ui.UIPasswordInput(
key="dbgpt.model.openai.api_key",
),
),
Parameter.build_from(
"Model",
"model",
type=str,
placeholder="Please select the model",
description="The model to use.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.model.openai.model",
),
),
],
inputs=[
IOField.build_from(
"User Name",
"user_name",
str,
description="The name of the user.",
),
],
outputs=[
IOField.build_from(
"Model info",
"model",
str,
description="The model info.",
),
],
)
def __init__(self, openai_api_key: str, model: str, **kwargs):
super().__init__(**kwargs)
self._openai_api_key = openai_api_key
self._model = model
async def map(self, user_name: str) -> str:
dict_dict = {
"openai_api_key": self._openai_api_key,
"model": self._model,
}
json_data = json.dumps(dict_dict, ensure_ascii=False)
return "Your name is %s, and your model info is %s." % (user_name, json_data)
class EndOperator(MapOperator[str, str]):
metadata = ViewMetadata(
label="End Operator",
name="end_operator",
category=OperatorCategory.EXAMPLE,
description="An example flow operator that ends the flow.",
parameters=[],
inputs=[
IOField.build_from(
"Input",
"input",
str,
description="The input to the end operator.",
),
],
outputs=[
IOField.build_from(
"Output",
"output",
str,
description="The output of the end operator.",
),
],
)
async def map(self, input: str) -> str:
return f"End operator received input: {input}"
@pytest.fixture
def json_flow():
operators = [MyVariablesOperator, EndOperator]
metadata_list = [operator.metadata.to_dict() for operator in operators]
node_names = {}
name_to_parameters_dict = {
"my_test_variables_operator": {
"openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
"model": "${dbgpt.model.openai.model:default_model@global}",
}
}
name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
ui_nodes = []
for metadata in metadata_list:
type_name = metadata["type_name"]
name = metadata["name"]
id = metadata["id"]
if type_name in node_names:
raise ValueError(f"Duplicate node type name: {type_name}")
# Replace id to flow data id.
metadata["id"] = f"{id}_0"
parameters = metadata["parameters"]
parameters_dict = name_to_parameters_dict.get(name, {})
for parameter in parameters:
parameter_name = parameter["name"]
if parameter_name in parameters_dict:
parameter["value"] = parameters_dict[parameter_name]
ui_nodes.append(
{
"width": 288,
"height": 352,
"id": metadata["id"],
"position": {
"x": -149.98120112708142,
"y": 666.9468497341901,
"zoom": 0.0,
},
"type": "customNode",
"position_absolute": {
"x": -149.98120112708142,
"y": 666.9468497341901,
"zoom": 0.0,
},
"data": metadata,
}
)
ui_edges = []
source_id = name_to_metadata_dict["my_test_variables_operator"]["id"]
target_id = name_to_metadata_dict["end_operator"]["id"]
ui_edges.append(
{
"source": source_id,
"target": target_id,
"source_order": 0,
"target_order": 0,
"id": f"{source_id}|{target_id}",
"source_handle": f"{source_id}|outputs|0",
"target_handle": f"{target_id}|inputs|0",
"type": "buttonedge",
}
)
return {
"nodes": ui_nodes,
"edges": ui_edges,
"viewport": {
"x": 509.2191773722104,
"y": -66.11286175905718,
"zoom": 1.252741002590748,
},
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"variables_provider",
[
(
{
"vars": {
"openai_api_key": {
"key": "${dbgpt.model.openai.api_key:my_key@global}",
"value": "my_openai_api_key",
"value_type": "str",
"category": "secret",
},
"model": {
"key": "${dbgpt.model.openai.model:default_model@global}",
"value": "GPT-4o",
"value_type": "str",
},
}
}
),
],
indirect=["variables_provider"],
)
async def test_build_flow(json_flow, variables_provider):
DAGVar.set_variables_provider(variables_provider)
flow_data = FlowData(**json_flow)
flow_panel = FlowPanel(
label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed"
)
factory = FlowFactory()
dag = factory.build(flow_panel)
leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
result = await leaf_node.call("Alice")
assert (
result
== "End operator received input: Your name is Alice, and your model info is "
'{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.'
)

View File

@@ -19,7 +19,7 @@ _UI_TYPE = Literal[
"time_picker",
"tree_select",
"upload",
"variable",
"variables",
"password",
"code_editor",
]

View File

@@ -380,6 +380,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not self._variables_provider:
return
# TODO: Resolve variables parallel
for attr, value in self.__dict__.items():
if isinstance(value, VariablesPlaceHolder):
resolved_value = await self.blocking_func_to_async(

View File

@@ -1,17 +1,15 @@
from contextlib import asynccontextmanager, contextmanager
from contextlib import asynccontextmanager
from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from .. import (
DAGContext,
DefaultWorkflowRunner,
InputOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator
@@ -102,3 +100,32 @@ async def stream_input_nodes(request):
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes
@asynccontextmanager
async def _create_variables(**kwargs):
vp = StorageVariablesProvider()
vars = kwargs.get("vars")
if vars and isinstance(vars, dict):
for param_key, param_var in vars.items():
key = param_var.get("key")
value = param_var.get("value")
value_type = param_var.get("value_type")
category = param_var.get("category", "common")
id = VariablesIdentifier.from_str_identifier(key)
vp.save(
StorageVariables.from_identifier(
id, value, value_type, label="", category=category
)
)
else:
raise ValueError("vars is required.")
yield vp
@pytest_asyncio.fixture
async def variables_provider(request):
param = getattr(request, "param", {})
async with _create_variables(**param) as vp:
yield vp

View File

@@ -54,7 +54,7 @@ async def _create_variables(**kwargs):
id, value, value_type, label="", category=category
)
)
variables[param_key] = VariablesPlaceHolder(param_key, key, value_type)
variables[param_key] = VariablesPlaceHolder(param_key, key)
else:
raise ValueError("vars is required.")
@@ -85,17 +85,17 @@ async def test_default_dag(default_dag: DAG):
{
"vars": {
"int_var": {
"key": "int_key@my_int_var@global",
"key": "${int_key:my_int_var@global}",
"value": 0,
"value_type": "int",
},
"str_var": {
"key": "str_key@my_str_var@global",
"key": "${str_key:my_str_var@global}",
"value": "1",
"value_type": "str",
},
"secret": {
"key": "secret_key@my_secret_var@global",
"key": "${secret_key:my_secret_var@global}",
"value": "2131sdsdf",
"value_type": "str",
"category": "secret",

View File

@@ -1,5 +1,6 @@
import base64
import os
from itertools import product
from cryptography.fernet import Fernet
@@ -10,6 +11,8 @@ from ..variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
build_variable_string,
parse_variable,
)
@@ -46,7 +49,7 @@ def test_storage_variables_provider():
encryption = SimpleEncryption()
provider = StorageVariablesProvider(storage, encryption)
full_key = "key@name@global"
full_key = "${key:name@global}"
value = "secret_value"
value_type = "str"
label = "test_label"
@@ -63,7 +66,7 @@ def test_storage_variables_provider():
def test_variables_identifier():
full_key = "key@name@global@scope_key@sys_code@user_name"
full_key = "${key:name@global:scope_key#sys_code%user_name}"
identifier = VariablesIdentifier.from_str_identifier(full_key)
assert identifier.key == "key"
@@ -112,3 +115,213 @@ def test_storage_variables():
assert dict_representation["value_type"] == value_type
assert dict_representation["category"] == category
assert dict_representation["scope"] == scope
def generate_test_cases(enable_escape=False):
# Define possible values for each field, including special characters for escaping
_EMPTY_ = "___EMPTY___"
fields = {
"name": [
None,
"test_name",
"test:name" if enable_escape else _EMPTY_,
"test::name" if enable_escape else _EMPTY_,
"test#name" if enable_escape else _EMPTY_,
"test##name" if enable_escape else _EMPTY_,
"test::@@@#22name" if enable_escape else _EMPTY_,
],
"scope": [
None,
"test_scope",
"test@scope" if enable_escape else _EMPTY_,
"test@@scope" if enable_escape else _EMPTY_,
"test:scope" if enable_escape else _EMPTY_,
"test:#:scope" if enable_escape else _EMPTY_,
],
"scope_key": [
None,
"test_scope_key",
"test:scope_key" if enable_escape else _EMPTY_,
],
"sys_code": [
None,
"test_sys_code",
"test#sys_code" if enable_escape else _EMPTY_,
],
"user_name": [
None,
"test_user_name",
"test%user_name" if enable_escape else _EMPTY_,
],
}
# Remove empty values
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
# Generate all possible combinations
combinations = product(*fields.values())
test_cases = []
for combo in combinations:
name, scope, scope_key, sys_code, user_name = combo
var_str = build_variable_string(
{
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
},
enable_escape=enable_escape,
)
# Construct the expected output
expected = {
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
}
test_cases.append((var_str, expected, enable_escape))
return test_cases
def test_parse_variables():
# Run test cases without escape
test_cases = generate_test_cases(enable_escape=False)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
result = parse_variable(input_str, enable_escape=enable_escape)
assert result == expected_output, f"Test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_test_cases(enable_escape=True)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
print(f"input_str: {input_str}, expected_output: {expected_output}")
result = parse_variable(input_str, enable_escape=enable_escape)
assert result == expected_output, f"Test case {i} failed with escape"
def generate_build_test_cases(enable_escape=False):
# Define possible values for each field, including special characters for escaping
_EMPTY_ = "___EMPTY___"
fields = {
"name": [
None,
"test_name",
"test:name" if enable_escape else _EMPTY_,
"test::name" if enable_escape else _EMPTY_,
"test\name" if enable_escape else _EMPTY_,
"test\\name" if enable_escape else _EMPTY_,
"test\:\#\@\%name" if enable_escape else _EMPTY_,
"test\::\###\@@\%%name" if enable_escape else _EMPTY_,
"test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_,
"test\:#:name" if enable_escape else _EMPTY_,
],
"scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_],
"scope_key": [
None,
"test_scope_key",
"test:scope_key" if enable_escape else _EMPTY_,
],
"sys_code": [
None,
"test_sys_code",
"test#sys_code" if enable_escape else _EMPTY_,
],
"user_name": [
None,
"test_user_name",
"test%user_name" if enable_escape else _EMPTY_,
],
}
# Remove empty values
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
# Generate all possible combinations
combinations = product(*fields.values())
test_cases = []
def escape_special_chars(s):
if not enable_escape or s is None:
return s
return (
s.replace(":", "\\:")
.replace("@", "\\@")
.replace("%", "\\%")
.replace("#", "\\#")
)
for combo in combinations:
name, scope, scope_key, sys_code, user_name = combo
# Construct the input dictionary
input_dict = {
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
}
input_dict_with_escape = {
k: escape_special_chars(v) for k, v in input_dict.items()
}
# Construct the expected variable string
expected_str = "${test_key"
if name:
expected_str += f":{input_dict_with_escape['name']}"
if scope or scope_key:
expected_str += "@"
if scope:
expected_str += input_dict_with_escape["scope"]
if scope_key:
expected_str += f":{input_dict_with_escape['scope_key']}"
if sys_code:
expected_str += f"#{input_dict_with_escape['sys_code']}"
if user_name:
expected_str += f"%{input_dict_with_escape['user_name']}"
expected_str += "}"
test_cases.append((input_dict, expected_str, enable_escape))
return test_cases
def test_build_variable_string():
# Run test cases without escape
test_cases = generate_build_test_cases(enable_escape=False)
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
result = build_variable_string(input_dict, enable_escape=enable_escape)
assert result == expected_str, f"Test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_build_test_cases(enable_escape=True)
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
print(f"input_dict: {input_dict}, expected_str: {expected_str}")
result = build_variable_string(input_dict, enable_escape=enable_escape)
assert result == expected_str, f"Test case {i} failed with escape"
def test_variable_string_round_trip():
# Run test cases without escape
test_cases = generate_test_cases(enable_escape=False)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
assert (
built_result == input_str
), f"Round trip test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_test_cases(enable_escape=True)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
assert built_result == input_str, f"Round trip test case {i} failed with escape"

View File

@@ -182,36 +182,18 @@ class VariablesIdentifier(ResourceIdentifier):
if not self.key or not self.name or not self.scope:
raise ValueError("Key, name, and scope are required.")
if any(
self.identifier_split in key
for key in [
self.key,
self.name,
self.scope,
self.scope_key,
self.sys_code,
self.user_name,
]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in "
f"key, name, scope, scope_key, sys_code, user_name."
)
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key or ""
for key in [
self.key,
self.name,
self.scope,
self.scope_key,
self.sys_code,
self.user_name,
]
return build_variable_string(
{
"key": self.key,
"name": self.name,
"scope": self.scope,
"scope_key": self.scope_key,
"sys_code": self.sys_code,
"user_name": self.user_name,
}
)
def to_dict(self) -> Dict:
@@ -230,33 +212,30 @@ class VariablesIdentifier(ResourceIdentifier):
}
@classmethod
def from_str_identifier(
cls, str_identifier: str, identifier_split: str = "@"
) -> "VariablesIdentifier":
def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier":
"""Create a VariablesIdentifier from a string identifier.
Args:
str_identifier (str): The string identifier.
identifier_split (str): The identifier split.
Returns:
VariablesIdentifier: The VariablesIdentifier.
"""
keys = str_identifier.split(identifier_split)
if not keys:
variable_dict = parse_variable(str_identifier)
if not variable_dict:
raise ValueError("Invalid string identifier.")
if len(keys) < 2:
if not variable_dict.get("key"):
raise ValueError("Invalid string identifier, must have key")
if not variable_dict.get("name"):
raise ValueError("Invalid string identifier, must have name")
if len(keys) < 3:
raise ValueError("Invalid string identifier, must have scope")
return cls(
key=keys[0],
name=keys[1],
scope=keys[2],
scope_key=keys[3] if len(keys) > 3 else None,
sys_code=keys[4] if len(keys) > 4 else None,
user_name=keys[5] if len(keys) > 5 else None,
key=variable_dict["key"],
name=variable_dict["name"],
scope=variable_dict.get("scope", "global"),
scope_key=variable_dict.get("scope_key"),
sys_code=variable_dict.get("sys_code"),
user_name=variable_dict.get("user_name"),
)
@@ -402,6 +381,26 @@ class VariablesProvider(BaseComponent, ABC):
"""Whether the variables provider support async."""
return False
def _convert_to_value_type(self, var: StorageVariables):
"""Convert the variable to the value type."""
if var.value is None:
return None
if var.value_type == "str":
return str(var.value)
elif var.value_type == "int":
return int(var.value)
elif var.value_type == "float":
return float(var.value)
elif var.value_type == "bool":
if var.value.lower() in ["true", "1"]:
return True
elif var.value.lower() in ["false", "0"]:
return False
else:
return bool(var.value)
else:
return var.value
class VariablesPlaceHolder:
"""The variables place holder."""
@@ -410,46 +409,20 @@ class VariablesPlaceHolder:
self,
param_name: str,
full_key: str,
value_type: str,
default_value: Any = _EMPTY_DEFAULT_VALUE,
):
"""Initialize the variables place holder."""
self.param_name = param_name
self.full_key = full_key
self.value_type = value_type
self.default_value = default_value
def parse(self, variables_provider: VariablesProvider) -> Any:
"""Parse the variables."""
value = variables_provider.get(self.full_key, self.default_value)
if value:
return self._cast_to_type(value)
else:
return value
def _cast_to_type(self, value: Any) -> Any:
if self.value_type == "str":
return str(value)
elif self.value_type == "int":
return int(value)
elif self.value_type == "float":
return float(value)
elif self.value_type == "bool":
if value.lower() in ["true", "1"]:
return True
elif value.lower() in ["false", "0"]:
return False
else:
return bool(value)
else:
return value
return variables_provider.get(self.full_key, self.default_value)
def __repr__(self):
"""Return the representation of the variables place holder."""
return (
f"<VariablesPlaceHolder "
f"{self.param_name} {self.full_key} {self.value_type}>"
)
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
class StorageVariablesProvider(VariablesProvider):
@@ -493,7 +466,7 @@ class StorageVariablesProvider(VariablesProvider):
and variable.salt
):
variable.value = self.encryption.decrypt(variable.value, variable.salt)
return variable.value
return self._convert_to_value_type(variable)
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
@@ -676,3 +649,287 @@ class BuiltinVariablesProvider(VariablesProvider, ABC):
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
raise NotImplementedError("BuiltinVariablesProvider does not support save.")
def parse_variable(
variable_str: str,
enable_escape: bool = True,
) -> Dict[str, Any]:
"""Parse the variable string.
Examples:
.. code-block:: python
cases = [
{
"full_key": "${test_key:test_name@test_scope:test_scope_key}",
"expected": {
"key": "test_key",
"name": "test_name",
"scope": "test_scope",
"scope_key": "test_scope_key",
"sys_code": None,
"user_name": None,
},
},
{
"full_key": "${test_key#test_sys_code}",
"expected": {
"key": "test_key",
"name": None,
"scope": None,
"scope_key": None,
"sys_code": "test_sys_code",
"user_name": None,
},
},
{
"full_key": "${test_key@:test_scope_key}",
"expected": {
"key": "test_key",
"name": None,
"scope": None,
"scope_key": "test_scope_key",
"sys_code": None,
"user_name": None,
},
},
]
for case in cases:
assert parse_variable(case["full_key"]) == case["expected"]
Args:
variable_str (str): The variable string.
enable_escape (bool): Whether to handle escaped characters.
Returns:
Dict[str, Any]: The parsed variable.
"""
if not variable_str.startswith("${") or not variable_str.endswith("}"):
raise ValueError(
"Invalid variable format, must start with '${' and end with '}'"
)
# Remove the surrounding ${ and }
content = variable_str[2:-1]
# Define placeholders for escaped characters
placeholders = {
r"\@": "__ESCAPED_AT__",
r"\#": "__ESCAPED_HASH__",
r"\%": "__ESCAPED_PERCENT__",
r"\:": "__ESCAPED_COLON__",
}
if enable_escape:
# Replace escaped characters with placeholders
for original, placeholder in placeholders.items():
content = content.replace(original, placeholder)
# Initialize the result dictionary
result: Dict[str, Optional[str]] = {
"key": None,
"name": None,
"scope": None,
"scope_key": None,
"sys_code": None,
"user_name": None,
}
# Split the content by special characters
parts = content.split("@")
# Parse key and name
key_name = parts[0].split("#")[0].split("%")[0]
if ":" in key_name:
result["key"], result["name"] = key_name.split(":", 1)
else:
result["key"] = key_name
# Parse scope and scope_key
if len(parts) > 1:
scope_part = parts[1].split("#")[0].split("%")[0]
if ":" in scope_part:
result["scope"], result["scope_key"] = scope_part.split(":", 1)
else:
result["scope"] = scope_part
# Parse sys_code
if "#" in content:
result["sys_code"] = content.split("#", 1)[1].split("%")[0]
# Parse user_name
if "%" in content:
result["user_name"] = content.split("%", 1)[1]
if enable_escape:
# Replace placeholders back with escaped characters
reverse_placeholders = {v: k[1:] for k, v in placeholders.items()}
for key, value in result.items():
if value:
for placeholder, original in reverse_placeholders.items():
result[key] = result[key].replace( # type: ignore
placeholder, original
)
# Replace empty strings with None
for key, value in result.items():
if value == "":
result[key] = None
return result
def _is_variable_format(value: str) -> bool:
if not value.startswith("${") or not value.endswith("}"):
return False
return True
def is_variable_string(variable_str: str) -> bool:
"""Check if the given string is a variable string.
A valid variable string should start with "${" and end with "}", and contain key
and name
Args:
variable_str (str): The string to check.
Returns:
bool: True if the string is a variable string, False otherwise.
"""
if not _is_variable_format(variable_str):
return False
try:
variable_dict = parse_variable(variable_str)
if not variable_dict.get("key"):
return False
if not variable_dict.get("name"):
return False
return True
except Exception:
return False
def is_variable_list_string(variable_str: str) -> bool:
"""Check if the given string is a variable string.
A valid variable list string should start with "${" and end with "}", and contain
key and not contain name
A valid variable list string means that the variable is a list of variables with the
same key.
Args:
variable_str (str): The string to check.
Returns:
bool: True if the string is a variable string, False otherwise.
"""
if not _is_variable_format(variable_str):
return False
try:
variable_dict = parse_variable(variable_str)
if not variable_dict.get("key"):
return False
if variable_dict.get("name"):
return False
return True
except Exception:
return False
def build_variable_string(
variable_dict: Dict[str, Any],
scope_sig: str = "@",
sys_code_sig: str = "#",
user_sig: str = "%",
kv_sig: str = ":",
enable_escape: bool = True,
) -> str:
"""Build a variable string from the given dictionary.
Args:
variable_dict (Dict[str, Any]): The dictionary containing the variable details.
scope_sig (str): The scope signature.
sys_code_sig (str): The sys code signature.
user_sig (str): The user signature.
kv_sig (str): The key-value split signature.
enable_escape (bool): Whether to escape special characters
Returns:
str: The formatted variable string.
Examples:
>>> build_variable_string(
... {
... "key": "test_key",
... "name": "test_name",
... "scope": "test_scope",
... "scope_key": "test_scope_key",
... "sys_code": "test_sys_code",
... "user_name": "test_user",
... }
... )
'${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}'
>>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"})
'${test_key@:test_scope_key}'
>>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"})
'${test_key#test_sys_code}'
>>> build_variable_string({"key": "test_key"})
'${test_key}'
"""
special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig}
# Replace None with ""
new_variable_dict = {key: value or "" for key, value in variable_dict.items()}
# Check if the variable_dict contains any special characters
for key, value in new_variable_dict.items():
if value != "" and any(char in value for char in special_chars):
if enable_escape:
# Escape special characters
new_variable_dict[key] = (
value.replace("@", "\\@")
.replace("#", "\\#")
.replace("%", "\\%")
.replace(":", "\\:")
)
else:
raise ValueError(
f"{key} contains special characters, error value: {value}, special "
f"characters: {special_chars}"
)
key = new_variable_dict.get("key", "")
name = new_variable_dict.get("name", "")
scope = new_variable_dict.get("scope", "")
scope_key = new_variable_dict.get("scope_key", "")
sys_code = new_variable_dict.get("sys_code", "")
user_name = new_variable_dict.get("user_name", "")
# Construct the base of the variable string
variable_str = f"${{{key}"
# Add name if present
if name:
variable_str += f":{name}"
# Add scope and scope_key if present
if scope or scope_key:
variable_str += f"@{scope}"
if scope_key:
variable_str += f":{scope_key}"
# Add sys_code if present
if sys_code:
variable_str += f"#{sys_code}"
# Add user_name if present
if user_name:
variable_str += f"%{user_name}"
# Close the variable string
variable_str += "}"
return variable_str