mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-24 11:00:17 +00:00
863 lines
31 KiB
Python
863 lines
31 KiB
Python
import argparse
|
|
import os
|
|
from collections import OrderedDict
|
|
from dataclasses import MISSING, asdict, dataclass, field, fields, is_dataclass
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
|
|
|
|
if TYPE_CHECKING:
|
|
from dbgpt._private.pydantic import BaseModel
|
|
|
|
# Sentinel string constant; presumably marks a parameter that has no default
# value -- NOTE(review): confirm against its users elsewhere in the package.
MISSING_DEFAULT_VALUE: str = "__MISSING_DEFAULT_VALUE__"
|
|
|
|
|
|
@dataclass
class ParameterDescription:
    """Serializable description of a single configuration parameter.

    Instances are built by ``_get_parameter_descriptions``,
    ``_extract_parameter_details`` and ``_get_base_model_descriptions`` and can be
    turned back into a dataclass by ``_build_parameter_class``.
    """

    # Whether a value must be supplied for this parameter.
    required: bool = field(default=False)
    # Fully qualified "module.ClassName" of the class owning the parameter.
    param_class: Optional[str] = field(default=None)
    # Parameter name (dataclass field / command line option name).
    param_name: Optional[str] = field(default=None)
    # Type name, e.g. "int" or "str" (exact spelling depends on the producer).
    param_type: Optional[str] = field(default=None)
    # Human-readable help text.
    description: Optional[str] = field(default=None)
    # Default value of the parameter, if any.
    default_value: Optional[Any] = field(default=None)
    # Allowed values when the parameter is restricted to a fixed set.
    valid_values: Optional[List[Any]] = field(default=None)
    # Extra metadata not represented by the fields above.
    ext_metadata: Optional[Dict[str, Any]] = field(default=None)
|
|
|
|
|
|
@dataclass
class BaseParameters:
    """Base class for dataclass-based parameter objects.

    Provides construction from a dict, in-place merging from another parameter
    object or dict, masked pretty-printing, and conversion to a dict or to a
    list of command line arguments.
    """

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = False
    ) -> "BaseParameters":
        """Create an instance of the dataclass from a dictionary.

        Args:
            data: A dictionary containing values for the dataclass fields.
            ignore_extra_fields: If True, any extra fields in the data dictionary
                that are not part of the dataclass will be ignored.
                If False, extra fields will raise an error. Defaults to False.

        Returns:
            An instance of the dataclass with values populated from the given
            dictionary.

        Raises:
            TypeError: If `ignore_extra_fields` is False and there are fields in
                the dictionary that aren't present in the dataclass.
        """
        all_field_names = {f.name for f in fields(cls)}
        if ignore_extra_fields:
            data = {key: value for key, value in data.items() if key in all_field_names}
        else:
            extra_fields = set(data.keys()) - all_field_names
            if extra_fields:
                # Sorted so the error message does not depend on set ordering.
                raise TypeError(
                    f"Unexpected fields: {', '.join(sorted(extra_fields))}"
                )
        return cls(**data)

    def update_from(self, source: Union["BaseParameters", dict]) -> bool:
        """Update this object's fields in place from another object or a dict.

        A field is updated only when the source provides a non-None value that
        differs from the current value, and the field is not tagged ``fixed`` in
        its metadata (``metadata={"tags": "..."}``, a comma-separated string;
        whitespace around each tag is ignored).

        Args:
            source: Another ``BaseParameters`` instance (same, parent or sibling
                type) or a dictionary keyed by field name. Fields that ``source``
                does not define are left unchanged.

        Returns:
            bool: True if at least one field was updated, otherwise False.

        Raises:
            ValueError: If ``source`` is neither a ``BaseParameters`` nor a dict.
        """
        if not isinstance(source, (BaseParameters, dict)):
            raise ValueError(
                "Source must be an instance of BaseParameters (or its derived class) or a dictionary."
            )

        updated = False  # Set once any field actually changes.
        for field_info in fields(self):
            raw_tags = field_info.metadata.get("tags")
            tags = [tag.strip() for tag in raw_tags.split(",")] if raw_tags else []
            if "fixed" in tags:
                continue  # Fixed fields are never overwritten.

            if isinstance(source, BaseParameters):
                # A parent-type source may not define every field of ``self``.
                new_value = getattr(source, field_info.name, None)
            else:
                new_value = source.get(field_info.name, None)

            if new_value is not None and new_value != getattr(self, field_info.name):
                setattr(self, field_info.name, new_value)
                updated = True

        return updated

    def __str__(self) -> str:
        """Return a banner-framed listing of all fields with private values masked."""
        return _get_dataclass_print_str(self)

    def to_command_args(self, args_prefix: str = "--") -> List[str]:
        """Convert the fields of the dataclass to a list of command line arguments.

        Args:
            args_prefix: Prefix prepended to every field name. Defaults to "--".

        Returns:
            The arguments produced by ``_dict_to_command_args`` for
            ``asdict(self)``; fields whose value is None are omitted.
        """
        return _dict_to_command_args(asdict(self), args_prefix=args_prefix)

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a (recursively converted) dict via ``asdict``."""
        return asdict(self)
|
|
|
|
|
|
@dataclass
class BaseServerParameters(BaseParameters):
    """Parameters shared by server processes.

    Covers the bind address, daemon mode, logging, and tracer / OpenTelemetry
    span exporter settings.
    """

    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "The host IP address to bind to."}
    )
    port: Optional[int] = field(
        default=None, metadata={"help": "The port number to bind to."}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run the server as a daemon."}
    )
    log_level: Optional[str] = field(
        default=None,
        metadata={
            "help": "Logging level",
            # Each level appears once so CLI help and choice lists have no
            # duplicate entries.
            "valid_values": [
                "FATAL",
                "ERROR",
                "WARNING",
                "INFO",
                "DEBUG",
                "NOTSET",
            ],
        },
    )
    log_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The filename to store log",
        },
    )
    tracer_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_to_open_telemetry: Optional[bool] = field(
        # Read once, when the class body is evaluated at module import time.
        default=os.getenv("TRACER_TO_OPEN_TELEMETRY", "False").lower() == "true",
        metadata={
            "help": "Whether send tracer span records to OpenTelemetry",
        },
    )
    otel_exporter_otlp_traces_endpoint: Optional[str] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_ENDPOINT` target to which the span "
            "exporter is going to send spans. The endpoint MUST be a valid URL host, "
            "and MAY contain a scheme (http or https), port and path. A scheme of https"
            " indicates a secure connection and takes precedence over this "
            "configuration setting.",
        },
    )
    otel_exporter_otlp_traces_insecure: Optional[bool] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_INSECURE` represents whether to enable "
            "client transport security for gRPC requests for spans. A scheme of https "
            "takes precedence over this configuration setting. Default: False"
        },
    )
    otel_exporter_otlp_traces_certificate: Optional[str] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_CERTIFICATE` stores the path to the "
            "certificate file for TLS credentials of gRPC client for traces. "
            "Should only be used for a secure connection for tracing",
        },
    )
    otel_exporter_otlp_traces_headers: Optional[str] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_HEADERS` contains the key-value pairs "
            "to be used as headers for spans associated with gRPC or HTTP requests.",
        },
    )
    otel_exporter_otlp_traces_timeout: Optional[int] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_TIMEOUT` is the maximum time the OTLP "
            "exporter will wait for each batch export for spans.",
        },
    )
    otel_exporter_otlp_traces_compression: Optional[str] = field(
        default=None,
        metadata={
            "help": "`OTEL_EXPORTER_OTLP_TRACES_COMPRESSION` is the same as "
            "`OTEL_EXPORTER_OTLP_COMPRESSION` but only for the span exporter. "
            "If both are present, this takes higher precedence.",
        },
    )
|
|
|
|
|
|
def _get_dataclass_print_str(obj):
|
|
class_name = obj.__class__.__name__
|
|
parameters = [
|
|
f"\n\n=========================== {class_name} ===========================\n"
|
|
]
|
|
for field_info in fields(obj):
|
|
value = _get_simple_privacy_field_value(obj, field_info)
|
|
parameters.append(f"{field_info.name}: {value}")
|
|
parameters.append(
|
|
"\n======================================================================\n\n"
|
|
)
|
|
return "\n".join(parameters)
|
|
|
|
|
|
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
|
|
"""Convert dict to a list of command line arguments
|
|
|
|
Args:
|
|
obj: dict
|
|
Returns:
|
|
A list of strings where each field is represented by two items:
|
|
one for the field name prefixed by args_prefix, and one for its value.
|
|
"""
|
|
args = []
|
|
for key, value in obj.items():
|
|
if value is None:
|
|
continue
|
|
args.append(f"{args_prefix}{key}")
|
|
args.append(str(value))
|
|
return args
|
|
|
|
|
|
def _get_simple_privacy_field_value(obj, field_info):
|
|
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
|
|
|
|
This function reads the metadata of a field to check if it's tagged with 'privacy'.
|
|
If the 'privacy' tag is present, then it modifies the value based on its type
|
|
for privacy concerns:
|
|
- int: returns -999
|
|
- float: returns -999.0
|
|
- bool: returns False
|
|
- str: if length > 5, masks the middle part and returns first and last char;
|
|
otherwise, returns "******"
|
|
|
|
Args:
|
|
obj: The dataclass instance.
|
|
field_info: A Field object that contains information about the dataclass field.
|
|
|
|
Returns:
|
|
The original or modified value of the field based on the privacy rules.
|
|
|
|
Example usage:
|
|
@dataclass
|
|
class Person:
|
|
name: str
|
|
age: int
|
|
ssn: str = field(metadata={"tags": "privacy"})
|
|
p = Person("Alice", 30, "123-45-6789")
|
|
print(_get_simple_privacy_field_value(p, Person.ssn)) # A******9
|
|
"""
|
|
tags = field_info.metadata.get("tags")
|
|
tags = [] if not tags else tags.split(",")
|
|
is_privacy = False
|
|
if tags and "privacy" in tags:
|
|
is_privacy = True
|
|
value = getattr(obj, field_info.name)
|
|
if not is_privacy or not value:
|
|
return value
|
|
field_type = EnvArgumentParser._get_argparse_type(field_info.type)
|
|
if field_type is int:
|
|
return -999
|
|
if field_type is float:
|
|
return -999.0
|
|
if field_type is bool:
|
|
return False
|
|
# str
|
|
if len(value) > 5:
|
|
return value[0] + "******" + value[-1]
|
|
return "******"
|
|
|
|
|
|
def _genenv_ignoring_key_case(
|
|
env_key: str, env_prefix: Optional[str] = None, default_value: Optional[str] = None
|
|
):
|
|
"""Get the value from the environment variable, ignoring the case of the key"""
|
|
if env_prefix:
|
|
env_key = env_prefix + env_key
|
|
return os.getenv(
|
|
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
|
)
|
|
|
|
|
|
def _genenv_ignoring_key_case_with_prefixes(
    env_key: str,
    env_prefixes: Optional[List[str]] = None,
    default_value: Optional[str] = None,
) -> Optional[str]:
    """Look up ``env_key`` under each prefix in turn, then without a prefix.

    Every lookup ignores key case (via ``_genenv_ignoring_key_case``). The first
    prefixed lookup yielding a non-empty value wins. Otherwise the unprefixed key
    is looked up, falling back to ``default_value``.

    Args:
        env_key: Environment variable name without a prefix.
        env_prefixes: Prefixes to try, in priority order.
        default_value: Returned when no variable is found.

    Returns:
        The environment value, or ``default_value`` (which may be None).
    """
    for env_prefix in env_prefixes or []:
        env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
        if env_var_value:
            return env_var_value
    return _genenv_ignoring_key_case(env_key, default_value=default_value)
|
|
|
|
|
|
class EnvArgumentParser:
    """Build argparse / click options from dataclass fields and populate
    dataclasses from environment variables and command line arguments."""

    @staticmethod
    def get_env_prefix(env_key: str) -> Optional[str]:
        """Return ``env_key`` with "-" replaced by "_" plus a trailing "_".

        Returns None when ``env_key`` is empty.
        """
        if not env_key:
            return None
        env_key = env_key.replace("-", "_")
        return env_key + "_"

    def parse_args_into_dataclass(
        self,
        dataclass_type: Type,
        env_prefixes: Optional[List[str]] = None,
        command_args: Optional[List[str]] = None,
        **kwargs,
    ) -> Any:
        """Parse parameters from environment variables and command lines and populate them into data class

        For each field, a command line value wins. Otherwise the value comes from an
        environment variable, then ``kwargs``, then the field's own default.
        Falsy environment / ``kwargs`` values are treated as unset.

        Args:
            dataclass_type: The dataclass to instantiate.
            env_prefixes: Environment variable prefixes to try, in priority order.
            command_args: Arguments to parse; ``None`` lets argparse read
                ``sys.argv[1:]``. Unknown arguments are ignored.
            **kwargs: Fallback values keyed by field name.

        Returns:
            An instance of ``dataclass_type``.

        Raises:
            ValueError: If an environment variable is set for a field whose type
                is not int, float, bool or str (optionally wrapped in Optional).
        """
        parser = argparse.ArgumentParser(allow_abbrev=False)
        for field_info in fields(dataclass_type):
            env_var_value: Any = _genenv_ignoring_key_case_with_prefixes(
                field_info.name, env_prefixes
            )
            if env_var_value:
                env_var_value = env_var_value.strip()
                field_type = field_info.type
                if field_type is int or field_type == Optional[int]:
                    env_var_value = int(env_var_value)
                elif field_type is float or field_type == Optional[float]:
                    env_var_value = float(env_var_value)
                elif field_type is bool or field_type == Optional[bool]:
                    env_var_value = env_var_value.lower() == "true"
                elif field_type is str or field_type == Optional[str]:
                    pass
                else:
                    raise ValueError(f"Unsupported parameter type {field_info.type}")
            if not env_var_value:
                env_var_value = kwargs.get(field_info.name)

            # Register the option; the env/kwargs value becomes its default.
            EnvArgumentParser._build_single_argparse_option(
                parser, field_info, env_var_value
            )

        # parse_known_args ignores arguments this dataclass does not declare.
        cmd_args, _ = parser.parse_known_args(args=command_args)
        for field_info in fields(dataclass_type):
            if field_info.name in cmd_args:
                cmd_line_value = getattr(cmd_args, field_info.name)
                if cmd_line_value is not None:
                    kwargs[field_info.name] = cmd_line_value

        return dataclass_type(**kwargs)

    @staticmethod
    def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
        """Create a standalone argparse parser with one option per dataclass field.

        Options are registered through ``_build_single_argparse_option`` so that
        boolean fields become ``store_true`` flags; ``type=bool`` would turn any
        non-empty string (including "False") into True and require a value.
        """
        parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
        for field_info in fields(dataclass_type):
            EnvArgumentParser._build_single_argparse_option(parser, field_info)
        return parser

    @staticmethod
    def _create_click_option_from_field(field_name: str, field: Any, is_func=True):
        """Build a click option for a dataclass field.

        Args:
            field_name: Option name without the leading "--".
            field: The ``dataclasses.Field`` describing the option.
            is_func: If True return a ``click.option`` decorator, otherwise a
                ``click.Option`` instance.
        """
        import click

        help_text = field.metadata.get("help", "")
        valid_values = field.metadata.get("valid_values", None)
        has_default = field.default is not MISSING or field.default_factory is not MISSING
        cli_params = {
            "default": None if field.default is MISSING else field.default,
            "help": help_text,
            "show_default": True,
            # default_factory fields have a default too; the dataclass builds it.
            "required": not has_default,
        }
        if valid_values:
            cli_params["type"] = click.Choice(valid_values)
        real_type = EnvArgumentParser._get_argparse_type(field.type)
        if real_type is int:
            cli_params["type"] = click.INT
        elif real_type is float:
            cli_params["type"] = click.FLOAT
        elif real_type is str:
            # Keep the Choice for restricted string options instead of replacing
            # it with a plain STRING type (which silently dropped valid_values).
            if not valid_values:
                cli_params["type"] = click.STRING
        elif real_type is bool:
            cli_params["is_flag"] = True
        name = f"--{field_name}"
        if is_func:
            return click.option(
                name,
                **cli_params,
            )
        else:
            return click.Option([name], **cli_params)

    @staticmethod
    def create_click_option(
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ):
        """Return a decorator adding one click option per dataclass field.

        Fields are merged with ``_merge_dataclass_types``: the first dataclass
        defining a name wins, and a non-empty result from ``_dynamic_factory``
        replaces ``dataclass_types``.
        """
        import functools

        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
        )

        def decorator(func):
            # Applied in reverse, presumably so click lists the options in field
            # order (the last decorator applied is listed first).
            for field_name, field_info in reversed(combined_fields.items()):
                option_decorator = EnvArgumentParser._create_click_option_from_field(
                    field_name, field_info
                )
                func = option_decorator(func)

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        return decorator

    @staticmethod
    def _create_raw_click_option(
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ):
        """Return ``click.Option`` objects for the merged fields, in reverse field order."""
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
        )
        options = []

        for field_name, field_info in reversed(combined_fields.items()):
            options.append(
                EnvArgumentParser._create_click_option_from_field(
                    field_name, field_info, is_func=False
                )
            )
        return options

    @staticmethod
    def create_argparse_option(
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ) -> argparse.ArgumentParser:
        """Create an argparse parser with options for the merged dataclass fields."""
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
        )
        parser = argparse.ArgumentParser()
        for _, field_info in reversed(combined_fields.items()):
            EnvArgumentParser._build_single_argparse_option(parser, field_info)
        return parser

    @staticmethod
    def _build_single_argparse_option(
        parser: argparse.ArgumentParser, field, default_value=None
    ):
        """Register one argparse option for a dataclass field.

        Args:
            parser: Parser to add the option to.
            field: The ``dataclasses.Field`` to expose as ``--<name>`` (plus
                ``-<short>`` when ``metadata["short"]`` is set).
            default_value: Overrides the field default when truthy.
        """
        help_text = field.metadata.get("help", "")
        valid_values = field.metadata.get("valid_values", None)
        short_name = field.metadata.get("short", None)
        argument_kwargs = {
            "type": EnvArgumentParser._get_argparse_type(field.type),
            "help": help_text,
            "choices": valid_values,
            "required": EnvArgumentParser._is_require_type(field.type),
        }
        if field.default is not MISSING:
            argument_kwargs["default"] = field.default
            argument_kwargs["required"] = False
        elif field.default_factory is not MISSING:
            # The dataclass builds this default itself when no value is given.
            argument_kwargs["required"] = False
        if default_value:
            argument_kwargs["default"] = default_value
            argument_kwargs["required"] = False
        if field.type is bool or field.type == Optional[bool]:
            argument_kwargs["action"] = "store_true"
            del argument_kwargs["type"]
            del argument_kwargs["choices"]
        names = []
        if short_name:
            names.append(f"-{short_name}")
        names.append(f"--{field.name}")
        parser.add_argument(*names, **argument_kwargs)

    @staticmethod
    def _get_argparse_type(field_type: Type) -> Type:
        """Map a field annotation (int/float/bool/str/dict, optionally Optional)
        to the plain Python type.

        Raises:
            ValueError: For any other annotation.
        """
        if field_type is int or field_type == Optional[int]:
            return int
        elif field_type is float or field_type == Optional[float]:
            return float
        elif field_type is bool or field_type == Optional[bool]:
            return bool
        elif field_type is str or field_type == Optional[str]:
            return str
        elif field_type is dict or field_type == Optional[dict]:
            return dict
        else:
            raise ValueError(f"Unsupported parameter type {field_type}")

    @staticmethod
    def _get_argparse_type_str(field_type: Type) -> str:
        """Return the type name ("int", "float", "bool", "dict" or "str") for a field annotation."""
        argparse_type = EnvArgumentParser._get_argparse_type(field_type)
        if argparse_type is int:
            return "int"
        elif argparse_type is float:
            return "float"
        elif argparse_type is bool:
            return "bool"
        elif argparse_type is dict:
            return "dict"
        else:
            return "str"

    @staticmethod
    def _is_require_type(field_type: Type) -> bool:
        """Return False for Optional[int], Optional[float] and Optional[bool], True otherwise."""
        return field_type not in [Optional[int], Optional[float], Optional[bool]]

    @staticmethod
    def _kwargs_to_env_key_value(
        kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__"
    ) -> Dict[str, str]:
        """Encode ``kwargs`` as ``{prefix + key: str(value)}`` environment entries."""
        return {prefix + k: str(v) for k, v in kwargs.items()}

    @staticmethod
    def _read_env_key_value(
        prefix: str = "__dbgpt_gunicorn__env_prefix__",
    ) -> List[str]:
        """Decode environment entries written by ``_kwargs_to_env_key_value``.

        Values "true"/"1" (any case) become bare flags, "false"/"0" are dropped,
        and everything else becomes ``--key value``.
        """
        env_args = []
        for key, value in os.environ.items():
            if not key.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove any
            # later occurrence inside the option name.
            arg_key = "--" + key[len(prefix):]
            lowered = value.lower()
            if lowered in ("true", "1"):
                # Flag args
                env_args.append(arg_key)
            elif lowered not in ("false", "0"):
                env_args.extend([arg_key, value])
        return env_args
|
|
|
|
|
|
def _merge_dataclass_types(
|
|
*dataclass_types: Type, _dynamic_factory: Optional[Callable[[], List[Type]]] = None
|
|
) -> OrderedDict:
|
|
combined_fields = OrderedDict()
|
|
if _dynamic_factory:
|
|
_types = _dynamic_factory()
|
|
if _types:
|
|
dataclass_types = list(_types) # type: ignore
|
|
for dataclass_type in dataclass_types:
|
|
for field in fields(dataclass_type):
|
|
if field.name not in combined_fields:
|
|
combined_fields[field.name] = field
|
|
return combined_fields
|
|
|
|
|
|
def _type_str_to_python_type(type_str: str) -> Type:
|
|
type_mapping: Dict[str, Type] = {
|
|
"int": int,
|
|
"float": float,
|
|
"bool": bool,
|
|
"str": str,
|
|
}
|
|
return type_mapping.get(type_str, str)
|
|
|
|
|
|
def _get_parameter_descriptions(
    dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
    """Describe every field of ``dataclass_type`` as a ``ParameterDescription``.

    A field is reported as required only when it has neither a ``default`` nor a
    ``default_factory``; for factory fields the factory is called to obtain the
    reported default.

    Args:
        dataclass_type: The parameter dataclass to describe.
        **kwargs: Per-field overrides for the reported default value.

    Returns:
        One description per field, in field order.
    """
    param_class = f"{dataclass_type.__module__}.{dataclass_type.__name__}"
    descriptions = []
    for field_info in fields(dataclass_type):
        ext_metadata = {
            k: v
            for k, v in field_info.metadata.items()
            if k not in ["help", "valid_values"]
        }
        has_default = field_info.default is not MISSING
        has_factory = field_info.default_factory is not MISSING
        if field_info.name in kwargs:
            default_value = kwargs[field_info.name]
        elif has_default:
            default_value = field_info.default
        elif has_factory:
            default_value = field_info.default_factory()
        else:
            default_value = None
        descriptions.append(
            ParameterDescription(
                param_class=param_class,
                param_name=field_info.name,
                param_type=EnvArgumentParser._get_argparse_type_str(field_info.type),
                description=field_info.metadata.get("help", None),
                required=not (has_default or has_factory),
                default_value=default_value,
                valid_values=field_info.metadata.get("valid_values", None),
                ext_metadata=ext_metadata,
            )
        )
    return descriptions
|
|
|
|
|
|
def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    """Import, or else create, the parameter dataclass described by ``desc``.

    If the class named by ``desc[0].param_class`` can be imported it is returned
    directly. Otherwise a new dataclass carrying that class name is built from
    the descriptions: one field per description, with ``help`` and
    ``valid_values`` stored in the field metadata.

    Args:
        desc: Parameter descriptions, all belonging to the same class.

    Returns:
        The imported or newly created dataclass type.

    Raises:
        ValueError: If ``desc`` is empty.
    """
    import copy
    import functools

    from dbgpt.util.module_utils import import_from_string

    if not desc:
        raise ValueError("Parameter descriptions cant be empty")
    param_class_str = desc[0].param_class
    class_name = None
    if param_class_str:
        param_class = import_from_string(param_class_str, ignore_import_error=True)
        if param_class:
            return param_class
        _, _, class_name = param_class_str.rpartition(".")

    fields_dict = {}  # field name -> dataclasses.field(...)
    annotations = {}  # field name -> Python type annotation

    for d in desc:
        # Copy so the caller's ext_metadata dict is not mutated.
        metadata = dict(d.ext_metadata) if d.ext_metadata else {}
        metadata["help"] = d.description
        metadata["valid_values"] = d.valid_values

        annotations[d.param_name] = _type_str_to_python_type(
            d.param_type  # type: ignore
        )  # Set type annotation
        if d.param_name == "ignore_patterns":
            # ``ignore_patterns`` always gets a None default, regardless of
            # its described default.
            fields_dict[d.param_name] = field(default=None, metadata=metadata)
        elif type(d.default_value).__hash__ is None:
            # ``dataclass`` rejects list/dict/set defaults (and, on Python 3.11+,
            # any unhashable default); give each instance its own deep copy.
            fields_dict[d.param_name] = field(
                default_factory=functools.partial(copy.deepcopy, d.default_value),
                metadata=metadata,
            )
        else:
            fields_dict[d.param_name] = field(
                default=d.default_value, metadata=metadata
            )

    # Create the new class. Note the setting of __annotations__ for type hints
    new_class = type(
        class_name,  # type: ignore
        (object,),
        {**fields_dict, "__annotations__": annotations},  # type: ignore
    )
    # Make it a dataclass
    result_class = dataclass(new_class)  # type: ignore

    return result_class
|
|
|
|
|
|
def _extract_parameter_details(
    parser: argparse.ArgumentParser,
    param_class: Optional[str] = None,
    skip_names: Optional[List[str]] = None,
    overwrite_default_values: Optional[Dict[str, Any]] = None,
) -> List[ParameterDescription]:
    """Describe the arguments registered on an argparse parser.

    Args:
        parser: The parser to inspect.
        param_class: Value stored as ``param_class`` on every description.
        skip_names: Parameter names (dashes normalized to "_") to leave out.
        overwrite_default_values: Per-parameter replacements for the default.

    Returns:
        One ``ParameterDescription`` per argument. Arguments whose default is
        ``argparse.SUPPRESS`` (e.g. the built-in ``--help``) are skipped.
    """
    if overwrite_default_values is None:
        overwrite_default_values = {}
    descriptions = []

    for action in parser._actions:
        if action.default == argparse.SUPPRESS:
            continue

        # Prefer the long option: a short alias such as "-p" may be listed first
        # (``EnvArgumentParser._build_single_argparse_option`` registers
        # "-<short>" before "--<name>").
        long_options = [opt for opt in action.option_strings if opt.startswith("--")]
        if long_options:
            param_name = long_options[0]
        elif action.option_strings:
            param_name = action.option_strings[0]
        else:
            param_name = action.dest
        if param_name.startswith("--"):
            param_name = param_name[2:]
        if param_name.startswith("-"):
            param_name = param_name[1:]
        param_name = param_name.replace("-", "_")

        if skip_names and param_name in skip_names:
            continue

        default_value = action.default
        if param_name in overwrite_default_values:
            default_value = overwrite_default_values[param_name]

        # Callable types (int, str, custom converters) are reported by name;
        # fall back to str() for callables without a __name__.
        if callable(action.type):
            arg_type = str(getattr(action.type, "__name__", action.type))
        else:
            arg_type = action.type

        valid_values = list(action.choices) if action.choices is not None else None

        # No extra metadata is derived from argparse actions.
        ext_metadata: Dict[str, Any] = {}

        descriptions.append(
            ParameterDescription(
                param_class=param_class,
                param_name=param_name,
                param_type=arg_type,
                default_value=default_value,
                description=action.help,
                required=action.required,
                valid_values=valid_values,
                ext_metadata=ext_metadata,
            )
        )

    return descriptions
|
|
|
|
|
|
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
|
if not obj:
|
|
return None
|
|
if is_dataclass(type(obj)):
|
|
params = {}
|
|
for field_info in fields(obj):
|
|
value = _get_simple_privacy_field_value(obj, field_info)
|
|
params[field_info.name] = value
|
|
return params
|
|
if isinstance(obj, dict):
|
|
return obj
|
|
return default_value
|
|
|
|
|
|
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    """Describe the fields of a pydantic model as ``ParameterDescription`` objects.

    Handles pydantic v1 and v2; the major version is read from
    ``pydantic.VERSION``.

    Args:
        model_cls: The pydantic model class to describe.

    Returns:
        One description per property in the model's JSON schema.
    """
    from dbgpt._private import pydantic

    version = int(pydantic.VERSION.split(".")[0])  # type: ignore
    schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
    # v2 exposes ``model_fields``; v1 keeps the equivalent mapping in
    # ``__fields__`` (``model_fields`` does not exist there).
    model_fields = (
        model_cls.model_fields if version >= 2 else model_cls.__fields__  # type: ignore
    )
    required_fields = set(schema.get("required", []))
    param_class = f"{model_cls.__module__}.{model_cls.__name__}"
    param_descs = []
    for field_name, field_schema in schema.get("properties", {}).items():
        model_field = model_fields[field_name]
        param_type = field_schema.get("type")
        if not param_type and "anyOf" in field_schema:
            # Take the first non-null alternative; entries such as "$ref" carry
            # no "type" key and are skipped.
            for any_of in field_schema["anyOf"]:
                any_of_type = any_of.get("type")
                if any_of_type and any_of_type != "null":
                    param_type = any_of_type
                    break
        if version >= 2:
            default_value = (
                model_field.default
                if hasattr(model_field, "default")
                and str(model_field.default) != "PydanticUndefined"
                else None
            )
        else:
            # NOTE(review): for v1 fields that allow None, ``default`` is ignored
            # in favour of ``default_factory()`` (or None) -- confirm intent.
            default_value = (
                model_field.default
                if not model_field.allow_none
                else (
                    model_field.default_factory()
                    if callable(model_field.default_factory)
                    else None
                )
            )
        description = field_schema.get("description", "")
        is_required = field_name in required_fields
        valid_values = None
        ext_metadata = None
        if hasattr(model_field, "field_info"):
            valid_values = (
                list(model_field.field_info.choices)
                if hasattr(model_field.field_info, "choices")
                else None
            )
            ext_metadata = (
                model_field.field_info.extra
                if hasattr(model_field.field_info, "extra")
                else None
            )
        param_desc = ParameterDescription(
            param_class=param_class,
            param_name=field_name,
            param_type=param_type,
            default_value=default_value,
            description=description,
            required=is_required,
            valid_values=valid_values,
            ext_metadata=ext_metadata,
        )
        param_descs.append(param_desc)
    return param_descs
|
|
|
|
|
|
class _SimpleArgParser:
|
|
def __init__(self, *args):
|
|
self.params = {arg.replace("_", "-"): None for arg in args}
|
|
|
|
def parse(self, args=None):
|
|
import sys
|
|
|
|
if args is None:
|
|
args = sys.argv[1:]
|
|
else:
|
|
args = list(args)
|
|
prev_arg = None
|
|
for arg in args:
|
|
if arg.startswith("--"):
|
|
if prev_arg:
|
|
self.params[prev_arg] = None
|
|
prev_arg = arg[2:]
|
|
else:
|
|
if prev_arg:
|
|
self.params[prev_arg] = arg
|
|
prev_arg = None
|
|
|
|
if prev_arg:
|
|
self.params[prev_arg] = None
|
|
|
|
def _get_param(self, key):
|
|
return self.params.get(key.replace("_", "-")) or self.params.get(key)
|
|
|
|
def __getattr__(self, item):
|
|
return self._get_param(item)
|
|
|
|
def __getitem__(self, key):
|
|
return self._get_param(key)
|
|
|
|
def get(self, key, default=None):
|
|
return self._get_param(key) or default
|
|
|
|
def __str__(self):
|
|
return "\n".join(
|
|
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
|
)
|
|
|
|
|
|
def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
    """Return a ``click.Command`` subclass whose dataclass-derived options are
    attached lazily.

    The options are built from ``dataclass_types`` (or ``_dynamic_factory``)
    and added the first time ``get_params`` is called with a context.
    """
    import click

    class LazyCommand(click.Command):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.dynamic_params_added = False

        def get_params(self, ctx):
            if ctx and not self.dynamic_params_added:
                raw_options = EnvArgumentParser._create_raw_click_option(
                    *dataclass_types, _dynamic_factory=_dynamic_factory
                )
                # _create_raw_click_option yields options in reverse field
                # order; reversing again appends them in field order.
                self.params.extend(reversed(raw_options))
                self.dynamic_params_added = True
            return super().get_params(ctx)

    return LazyCommand
|