From a348e6c0d70e4c5cfcb82f695164a894548c46cf Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sun, 11 Aug 2024 16:09:26 +0800 Subject: [PATCH] feat(core): Support complex variables parsing --- dbgpt/core/awel/flow/base.py | 27 +- .../awel/flow/tests/test_flow_variables.py | 223 ++++++++++ dbgpt/core/awel/flow/ui.py | 2 +- dbgpt/core/awel/operators/base.py | 1 + dbgpt/core/awel/tests/conftest.py | 43 +- dbgpt/core/awel/tests/test_dag_variables.py | 8 +- dbgpt/core/interface/tests/test_variables.py | 217 +++++++++- dbgpt/core/interface/variables.py | 399 ++++++++++++++---- 8 files changed, 827 insertions(+), 93 deletions(-) create mode 100644 dbgpt/core/awel/flow/tests/test_flow_variables.py diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 57081420e..61e0dfa75 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -380,27 +380,40 @@ class Parameter(TypeMetadata, Serializable): return values @classmethod - def _covert_to_real_type(cls, type_cls: str, v: Any): + def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any: if type_cls and v is not None: + typed_value: Any = v try: # Try to convert the value to the type. if type_cls == "builtins.str": - return str(v) + typed_value = str(v) elif type_cls == "builtins.int": - return int(v) + typed_value = int(v) elif type_cls == "builtins.float": - return float(v) + typed_value = float(v) elif type_cls == "builtins.bool": if str(v).lower() in ["false", "0", "", "no", "off"]: return False - return bool(v) + typed_value = bool(v) + return typed_value except ValueError: raise ValidationError(f"Value '{v}' is not valid for type {type_cls}") return v def get_typed_value(self) -> Any: - """Get the typed value.""" - return self._covert_to_real_type(self.type_cls, self.value) + """Get the typed value. + + Returns: + Any: The typed value. VariablesPlaceHolder if the value is a variable + string. Otherwise, the real type value. + """ + from ...interface.variables import VariablesPlaceHolder, is_variable_string + + is_variables = is_variable_string(self.value) if self.value else False + if is_variables and self.value is not None and isinstance(self.value, str): + return VariablesPlaceHolder(self.name, self.value) + else: + return self._covert_to_real_type(self.type_cls, self.value) def get_typed_default(self) -> Any: """Get the typed default.""" diff --git a/dbgpt/core/awel/flow/tests/test_flow_variables.py b/dbgpt/core/awel/flow/tests/test_flow_variables.py new file mode 100644 index 000000000..eaa548b09 --- /dev/null +++ b/dbgpt/core/awel/flow/tests/test_flow_variables.py @@ -0,0 +1,223 @@ +import json +from typing import cast + +import pytest + +from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator +from dbgpt.core.awel.flow import ( + IOField, + OperatorCategory, + Parameter, + VariablesDynamicOptions, + ViewMetadata, + ui, +) +from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel + +from ...tests.conftest import variables_provider + + +class MyVariablesOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="My Test Variables Operator", + name="my_test_variables_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a variables option.", + parameters=[ + Parameter.build_from( + "OpenAI API Key", + "openai_api_key", + type=str, + placeholder="Please select the OpenAI API key", + description="The OpenAI API key to use.", + options=VariablesDynamicOptions(), + ui=ui.UIPasswordInput( + key="dbgpt.model.openai.api_key", + ), + ), + Parameter.build_from( + "Model", + "model", + type=str, + placeholder="Please select the model", + description="The model to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.model.openai.model", + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Model info", + "model", + str, + description="The model info.", + ), + ], + ) + + def __init__(self, openai_api_key: str, model: str, **kwargs): + super().__init__(**kwargs) + self._openai_api_key = openai_api_key + self._model = model + + async def map(self, user_name: str) -> str: + dict_dict = { + "openai_api_key": self._openai_api_key, + "model": self._model, + } + json_data = json.dumps(dict_dict, ensure_ascii=False) + return "Your name is %s, and your model info is %s." % (user_name, json_data) + + +class EndOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="End Operator", + name="end_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that ends the flow.", + parameters=[], + inputs=[ + IOField.build_from( + "Input", + "input", + str, + description="The input to the end operator.", + ), + ], + outputs=[ + IOField.build_from( + "Output", + "output", + str, + description="The output of the end operator.", + ), + ], + ) + + async def map(self, input: str) -> str: + return f"End operator received input: {input}" + + +@pytest.fixture +def json_flow(): + operators = [MyVariablesOperator, EndOperator] + metadata_list = [operator.metadata.to_dict() for operator in operators] + node_names = {} + name_to_parameters_dict = { + "my_test_variables_operator": { + "openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}", + "model": "${dbgpt.model.openai.model:default_model@global}", + } + } + name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list} + ui_nodes = [] + for metadata in metadata_list: + type_name = metadata["type_name"] + name = metadata["name"] + id = metadata["id"] + if type_name in node_names: + raise ValueError(f"Duplicate node type name: {type_name}") + # Replace id to flow data id. + metadata["id"] = f"{id}_0" + parameters = metadata["parameters"] + parameters_dict = name_to_parameters_dict.get(name, {}) + for parameter in parameters: + parameter_name = parameter["name"] + if parameter_name in parameters_dict: + parameter["value"] = parameters_dict[parameter_name] + ui_nodes.append( + { + "width": 288, + "height": 352, + "id": metadata["id"], + "position": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "type": "customNode", + "position_absolute": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "data": metadata, + } + ) + + ui_edges = [] + source_id = name_to_metadata_dict["my_test_variables_operator"]["id"] + target_id = name_to_metadata_dict["end_operator"]["id"] + ui_edges.append( + { + "source": source_id, + "target": target_id, + "source_order": 0, + "target_order": 0, + "id": f"{source_id}|{target_id}", + "source_handle": f"{source_id}|outputs|0", + "target_handle": f"{target_id}|inputs|0", + "type": "buttonedge", + } + ) + return { + "nodes": ui_nodes, + "edges": ui_edges, + "viewport": { + "x": 509.2191773722104, + "y": -66.11286175905718, + "zoom": 1.252741002590748, + }, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_provider", + [ + ( + { + "vars": { + "openai_api_key": { + "key": "${dbgpt.model.openai.api_key:my_key@global}", + "value": "my_openai_api_key", + "value_type": "str", + "category": "secret", + }, + "model": { + "key": "${dbgpt.model.openai.model:default_model@global}", + "value": "GPT-4o", + "value_type": "str", + }, + } + } + ), + ], + indirect=["variables_provider"], +) +async def test_build_flow(json_flow, variables_provider): + DAGVar.set_variables_provider(variables_provider) + flow_data = FlowData(**json_flow) + flow_panel = FlowPanel( + label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed" + ) + factory = FlowFactory() + dag = factory.build(flow_panel) + + leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0]) + result = await leaf_node.call("Alice") + assert ( + result + == "End operator received input: Your name is Alice, and your model info is " + '{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.' + ) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 66b413a9f..875547e9a 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -19,7 +19,7 @@ _UI_TYPE = Literal[ "time_picker", "tree_select", "upload", - "variable", + "variables", "password", "code_editor", ] diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index 77f056042..0933a4547 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -380,6 +380,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if not self._variables_provider: return + # TODO: Resolve variables parallel for attr, value in self.__dict__.items(): if isinstance(value, VariablesPlaceHolder): resolved_value = await self.blocking_func_to_async( diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index d68ddcfc8..607783028 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -1,17 +1,15 @@ -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from typing import AsyncIterator, List import pytest import pytest_asyncio -from .. import ( - DAGContext, - DefaultWorkflowRunner, - InputOperator, - SimpleInputSource, - TaskState, - WorkflowRunner, +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, ) +from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource from ..task.task_impl import _is_async_iterator @@ -102,3 +100,32 @@ async def stream_input_nodes(request): param["is_stream"] = True async with _create_input_node(**param) as input_nodes: yield input_nodes + + +@asynccontextmanager +async def _create_variables(**kwargs): + vp = StorageVariablesProvider() + vars = kwargs.get("vars") + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + vp.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + else: + raise ValueError("vars is required.") + + yield vp + + +@pytest_asyncio.fixture +async def variables_provider(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as vp: + yield vp diff --git a/dbgpt/core/awel/tests/test_dag_variables.py b/dbgpt/core/awel/tests/test_dag_variables.py index 88c9b6660..8bdb29143 100644 --- a/dbgpt/core/awel/tests/test_dag_variables.py +++ b/dbgpt/core/awel/tests/test_dag_variables.py @@ -54,7 +54,7 @@ async def _create_variables(**kwargs): id, value, value_type, label="", category=category ) ) - variables[param_key] = VariablesPlaceHolder(param_key, key, value_type) + variables[param_key] = VariablesPlaceHolder(param_key, key) else: raise ValueError("vars is required.") @@ -85,17 +85,17 @@ async def test_default_dag(default_dag: DAG): { "vars": { "int_var": { - "key": "int_key@my_int_var@global", + "key": "${int_key:my_int_var@global}", "value": 0, "value_type": "int", }, "str_var": { - "key": "str_key@my_str_var@global", + "key": "${str_key:my_str_var@global}", "value": "1", "value_type": "str", }, "secret": { - "key": "secret_key@my_secret_var@global", + "key": "${secret_key:my_secret_var@global}", "value": "2131sdsdf", "value_type": "str", "category": "secret", diff --git a/dbgpt/core/interface/tests/test_variables.py b/dbgpt/core/interface/tests/test_variables.py index 3b7ab8157..313657b4e 100644 --- a/dbgpt/core/interface/tests/test_variables.py +++ b/dbgpt/core/interface/tests/test_variables.py @@ -1,5 +1,6 @@ import base64 import os +from itertools import product from cryptography.fernet import Fernet @@ -10,6 +11,8 @@ from ..variables import ( StorageVariables, StorageVariablesProvider, VariablesIdentifier, + build_variable_string, + parse_variable, ) @@ -46,7 +49,7 @@ def test_storage_variables_provider(): encryption = SimpleEncryption() provider = StorageVariablesProvider(storage, encryption) - full_key = "key@name@global" + full_key = "${key:name@global}" value = "secret_value" value_type = "str" label = "test_label" @@ -63,7 +66,7 @@ def test_storage_variables_provider(): def test_variables_identifier(): - full_key = "key@name@global@scope_key@sys_code@user_name" + full_key = "${key:name@global:scope_key#sys_code%user_name}" identifier = VariablesIdentifier.from_str_identifier(full_key) assert identifier.key == "key" @@ -112,3 +115,213 @@ def test_storage_variables(): assert dict_representation["value_type"] == value_type assert dict_representation["category"] == category assert dict_representation["scope"] == scope + + +def generate_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test#name" if enable_escape else _EMPTY_, + "test##name" if enable_escape else _EMPTY_, + "test::@@@#22name" if enable_escape else _EMPTY_, + ], + "scope": [ + None, + "test_scope", + "test@scope" if enable_escape else _EMPTY_, + "test@@scope" if enable_escape else _EMPTY_, + "test:scope" if enable_escape else _EMPTY_, + "test:#:scope" if enable_escape else _EMPTY_, + ], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + var_str = build_variable_string( + { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + }, + enable_escape=enable_escape, + ) + + # Construct the expected output + expected = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + } + + test_cases.append((var_str, expected, enable_escape)) + + return test_cases + + +def test_parse_variables(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + result = parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + print(f"input_str: {input_str}, expected_output: {expected_output}") + result = parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed with escape" + + +def generate_build_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test\name" if enable_escape else _EMPTY_, + "test\\name" if enable_escape else _EMPTY_, + "test\:\#\@\%name" if enable_escape else _EMPTY_, + "test\::\###\@@\%%name" if enable_escape else _EMPTY_, + "test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_, + "test\:#:name" if enable_escape else _EMPTY_, + ], + "scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + + def escape_special_chars(s): + if not enable_escape or s is None: + return s + return ( + s.replace(":", "\\:") + .replace("@", "\\@") + .replace("%", "\\%") + .replace("#", "\\#") + ) + + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + # Construct the input dictionary + input_dict = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + } + input_dict_with_escape = { + k: escape_special_chars(v) for k, v in input_dict.items() + } + + # Construct the expected variable string + expected_str = "${test_key" + if name: + expected_str += f":{input_dict_with_escape['name']}" + if scope or scope_key: + expected_str += "@" + if scope: + expected_str += input_dict_with_escape["scope"] + if scope_key: + expected_str += f":{input_dict_with_escape['scope_key']}" + if sys_code: + expected_str += f"#{input_dict_with_escape['sys_code']}" + if user_name: + expected_str += f"%{input_dict_with_escape['user_name']}" + expected_str += "}" + + test_cases.append((input_dict, expected_str, enable_escape)) + + return test_cases + + +def test_build_variable_string(): + # Run test cases without escape + test_cases = generate_build_test_cases(enable_escape=False) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_build_test_cases(enable_escape=True) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + print(f"input_dict: {input_dict}, expected_str: {expected_str}") + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed with escape" + + +def test_variable_string_round_trip(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = build_variable_string(parsed_result, enable_escape=enable_escape) + assert ( + built_result == input_str + ), f"Round trip test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = build_variable_string(parsed_result, enable_escape=enable_escape) + assert built_result == input_str, f"Round trip test case {i} failed with escape" diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 8f99d1e30..7e308127c 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -182,36 +182,18 @@ class VariablesIdentifier(ResourceIdentifier): if not self.key or not self.name or not self.scope: raise ValueError("Key, name, and scope are required.") - if any( - self.identifier_split in key - for key in [ - self.key, - self.name, - self.scope, - self.scope_key, - self.sys_code, - self.user_name, - ] - if key is not None - ): - raise ValueError( - f"identifier_split {self.identifier_split} is not allowed in " - f"key, name, scope, scope_key, sys_code, user_name." - ) - @property def str_identifier(self) -> str: """Return the string identifier of the identifier.""" - return self.identifier_split.join( - key or "" - for key in [ - self.key, - self.name, - self.scope, - self.scope_key, - self.sys_code, - self.user_name, - ] + return build_variable_string( + { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } ) def to_dict(self) -> Dict: @@ -230,33 +212,30 @@ class VariablesIdentifier(ResourceIdentifier): } @classmethod - def from_str_identifier( - cls, str_identifier: str, identifier_split: str = "@" - ) -> "VariablesIdentifier": + def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier": """Create a VariablesIdentifier from a string identifier. Args: str_identifier (str): The string identifier. - identifier_split (str): The identifier split. Returns: VariablesIdentifier: The VariablesIdentifier. """ - keys = str_identifier.split(identifier_split) - if not keys: + variable_dict = parse_variable(str_identifier) + if not variable_dict: raise ValueError("Invalid string identifier.") - if len(keys) < 2: + if not variable_dict.get("key"): + raise ValueError("Invalid string identifier, must have key") + if not variable_dict.get("name"): raise ValueError("Invalid string identifier, must have name") - if len(keys) < 3: - raise ValueError("Invalid string identifier, must have scope") return cls( - key=keys[0], - name=keys[1], - scope=keys[2], - scope_key=keys[3] if len(keys) > 3 else None, - sys_code=keys[4] if len(keys) > 4 else None, - user_name=keys[5] if len(keys) > 5 else None, + key=variable_dict["key"], + name=variable_dict["name"], + scope=variable_dict.get("scope", "global"), + scope_key=variable_dict.get("scope_key"), + sys_code=variable_dict.get("sys_code"), + user_name=variable_dict.get("user_name"), ) @@ -402,6 +381,26 @@ class VariablesProvider(BaseComponent, ABC): """Whether the variables provider support async.""" return False + def _convert_to_value_type(self, var: StorageVariables): + """Convert the variable to the value type.""" + if var.value is None: + return None + if var.value_type == "str": + return str(var.value) + elif var.value_type == "int": + return int(var.value) + elif var.value_type == "float": + return float(var.value) + elif var.value_type == "bool": + if var.value.lower() in ["true", "1"]: + return True + elif var.value.lower() in ["false", "0"]: + return False + else: + return bool(var.value) + else: + return var.value + class VariablesPlaceHolder: """The variables place holder.""" @@ -410,46 +409,20 @@ class VariablesPlaceHolder: self, param_name: str, full_key: str, - value_type: str, default_value: Any = _EMPTY_DEFAULT_VALUE, ): """Initialize the variables place holder.""" self.param_name = param_name self.full_key = full_key - self.value_type = value_type self.default_value = default_value def parse(self, variables_provider: VariablesProvider) -> Any: """Parse the variables.""" - value = variables_provider.get(self.full_key, self.default_value) - if value: - return self._cast_to_type(value) - else: - return value - - def _cast_to_type(self, value: Any) -> Any: - if self.value_type == "str": - return str(value) - elif self.value_type == "int": - return int(value) - elif self.value_type == "float": - return float(value) - elif self.value_type == "bool": - if value.lower() in ["true", "1"]: - return True - elif value.lower() in ["false", "0"]: - return False - else: - return bool(value) - else: - return value + return variables_provider.get(self.full_key, self.default_value) def __repr__(self): """Return the representation of the variables place holder.""" - return ( - f"" - ) + return f"" class StorageVariablesProvider(VariablesProvider): @@ -493,7 +466,7 @@ class StorageVariablesProvider(VariablesProvider): and variable.salt ): variable.value = self.encryption.decrypt(variable.value, variable.salt) - return variable.value + return self._convert_to_value_type(variable) def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" @@ -676,3 +649,287 @@ class BuiltinVariablesProvider(VariablesProvider, ABC): def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" raise NotImplementedError("BuiltinVariablesProvider does not support save.") + + +def parse_variable( + variable_str: str, + enable_escape: bool = True, +) -> Dict[str, Any]: + """Parse the variable string. + + Examples: + .. code-block:: python + + cases = [ + { + "full_key": "${test_key:test_name@test_scope:test_scope_key}", + "expected": { + "key": "test_key", + "name": "test_name", + "scope": "test_scope", + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + { + "full_key": "${test_key#test_sys_code}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": None, + "sys_code": "test_sys_code", + "user_name": None, + }, + }, + { + "full_key": "${test_key@:test_scope_key}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + ] + for case in cases: + assert parse_variable(case["full_key"]) == case["expected"] + Args: + variable_str (str): The variable string. + enable_escape (bool): Whether to handle escaped characters. + Returns: + Dict[str, Any]: The parsed variable. + """ + if not variable_str.startswith("${") or not variable_str.endswith("}"): + raise ValueError( + "Invalid variable format, must start with '${' and end with '}'" + ) + + # Remove the surrounding ${ and } + content = variable_str[2:-1] + + # Define placeholders for escaped characters + placeholders = { + r"\@": "__ESCAPED_AT__", + r"\#": "__ESCAPED_HASH__", + r"\%": "__ESCAPED_PERCENT__", + r"\:": "__ESCAPED_COLON__", + } + + if enable_escape: + # Replace escaped characters with placeholders + for original, placeholder in placeholders.items(): + content = content.replace(original, placeholder) + + # Initialize the result dictionary + result: Dict[str, Optional[str]] = { + "key": None, + "name": None, + "scope": None, + "scope_key": None, + "sys_code": None, + "user_name": None, + } + + # Split the content by special characters + parts = content.split("@") + + # Parse key and name + key_name = parts[0].split("#")[0].split("%")[0] + if ":" in key_name: + result["key"], result["name"] = key_name.split(":", 1) + else: + result["key"] = key_name + + # Parse scope and scope_key + if len(parts) > 1: + scope_part = parts[1].split("#")[0].split("%")[0] + if ":" in scope_part: + result["scope"], result["scope_key"] = scope_part.split(":", 1) + else: + result["scope"] = scope_part + + # Parse sys_code + if "#" in content: + result["sys_code"] = content.split("#", 1)[1].split("%")[0] + + # Parse user_name + if "%" in content: + result["user_name"] = content.split("%", 1)[1] + + if enable_escape: + # Replace placeholders back with escaped characters + reverse_placeholders = {v: k[1:] for k, v in placeholders.items()} + for key, value in result.items(): + if value: + for placeholder, original in reverse_placeholders.items(): + result[key] = result[key].replace( # type: ignore + placeholder, original + ) + + # Replace empty strings with None + for key, value in result.items(): + if value == "": + result[key] = None + + return result + + +def _is_variable_format(value: str) -> bool: + if not value.startswith("${") or not value.endswith("}"): + return False + return True + + +def is_variable_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable string should start with "${" and end with "}", and contain key + and name + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. + """ + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if not variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def is_variable_list_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable list string should start with "${" and end with "}", and contain + key and not contain name + + A valid variable list string means that the variable is a list of variables with the + same key. + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. + """ + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def build_variable_string( + variable_dict: Dict[str, Any], + scope_sig: str = "@", + sys_code_sig: str = "#", + user_sig: str = "%", + kv_sig: str = ":", + enable_escape: bool = True, +) -> str: + """Build a variable string from the given dictionary. + + Args: + variable_dict (Dict[str, Any]): The dictionary containing the variable details. + scope_sig (str): The scope signature. + sys_code_sig (str): The sys code signature. + user_sig (str): The user signature. + kv_sig (str): The key-value split signature. + enable_escape (bool): Whether to escape special characters + + Returns: + str: The formatted variable string. + + Examples: + >>> build_variable_string( + ... { + ... "key": "test_key", + ... "name": "test_name", + ... "scope": "test_scope", + ... "scope_key": "test_scope_key", + ... "sys_code": "test_sys_code", + ... "user_name": "test_user", + ... } + ... ) + '${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}' + + >>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"}) + '${test_key@:test_scope_key}' + + >>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"}) + '${test_key#test_sys_code}' + + >>> build_variable_string({"key": "test_key"}) + '${test_key}' + """ + special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig} + # Replace None with "" + new_variable_dict = {key: value or "" for key, value in variable_dict.items()} + + # Check if the variable_dict contains any special characters + for key, value in new_variable_dict.items(): + if value != "" and any(char in value for char in special_chars): + if enable_escape: + # Escape special characters + new_variable_dict[key] = ( + value.replace("@", "\\@") + .replace("#", "\\#") + .replace("%", "\\%") + .replace(":", "\\:") + ) + else: + raise ValueError( + f"{key} contains special characters, error value: {value}, special " + f"characters: {special_chars}" + ) + + key = new_variable_dict.get("key", "") + name = new_variable_dict.get("name", "") + scope = new_variable_dict.get("scope", "") + scope_key = new_variable_dict.get("scope_key", "") + sys_code = new_variable_dict.get("sys_code", "") + user_name = new_variable_dict.get("user_name", "") + + # Construct the base of the variable string + variable_str = f"${{{key}" + + # Add name if present + if name: + variable_str += f":{name}" + + # Add scope and scope_key if present + if scope or scope_key: + variable_str += f"@{scope}" + if scope_key: + variable_str += f":{scope_key}" + + # Add sys_code if present + if sys_code: + variable_str += f"#{sys_code}" + + # Add user_name if present + if user_name: + variable_str += f"%{user_name}" + + # Close the variable string + variable_str += "}" + + return variable_str