mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
chore: Add pylint for DB-GPT core lib (#1076)
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def PublicAPI(*args, **kwargs):
|
||||
"""Decorator to mark a function or class as a public API.
|
||||
|
||||
@@ -64,7 +67,7 @@ def DeveloperAPI(*args, **kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def _modify_docstring(obj, message: str = None):
|
||||
def _modify_docstring(obj, message: Optional[str] = None):
|
||||
if not message:
|
||||
return
|
||||
if not obj.__doc__:
|
||||
@@ -81,6 +84,7 @@ def _modify_docstring(obj, message: str = None):
|
||||
|
||||
if min_indent == float("inf"):
|
||||
min_indent = 0
|
||||
min_indent = int(min_indent)
|
||||
indented_message = message.rstrip() + "\n" + (" " * min_indent)
|
||||
obj.__doc__ = indented_message + original_doc
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import cache
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
|
||||
class AppConfig:
|
||||
@@ -46,7 +46,7 @@ class AppConfig:
|
||||
"""
|
||||
env_lang = (
|
||||
"zh"
|
||||
if os.getenv("LANG") and os.getenv("LANG").startswith("zh")
|
||||
if os.getenv("LANG") and cast(str, os.getenv("LANG")).startswith("zh")
|
||||
else default
|
||||
)
|
||||
return self.get("dbgpt.app.global.language", env_lang)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Utilities for formatting strings."""
|
||||
import json
|
||||
from string import Formatter
|
||||
from typing import Any, List, Mapping, Sequence, Union
|
||||
from typing import Any, List, Mapping, Sequence, Set, Union
|
||||
|
||||
|
||||
class StrictFormatter(Formatter):
|
||||
@@ -9,7 +9,7 @@ class StrictFormatter(Formatter):
|
||||
|
||||
def check_unused_args(
|
||||
self,
|
||||
used_args: Sequence[Union[int, str]],
|
||||
used_args: Set[Union[int, str]],
|
||||
args: Sequence,
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> None:
|
||||
@@ -39,7 +39,7 @@ class StrictFormatter(Formatter):
|
||||
class NoStrictFormatter(StrictFormatter):
|
||||
def check_unused_args(
|
||||
self,
|
||||
used_args: Sequence[Union[int, str]],
|
||||
used_args: Set[Union[int, str]],
|
||||
args: Sequence,
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> None:
|
||||
|
@@ -12,14 +12,14 @@ MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
|
||||
|
||||
@dataclass
|
||||
class ParameterDescription:
|
||||
param_class: str
|
||||
param_name: str
|
||||
param_type: str
|
||||
default_value: Optional[Any]
|
||||
description: str
|
||||
required: Optional[bool]
|
||||
valid_values: Optional[List[Any]]
|
||||
ext_metadata: Dict
|
||||
required: bool = False
|
||||
param_class: Optional[str] = None
|
||||
param_name: Optional[str] = None
|
||||
param_type: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
default_value: Optional[Any] = None
|
||||
valid_values: Optional[List[Any]] = None
|
||||
ext_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -186,7 +186,9 @@ def _get_simple_privacy_field_value(obj, field_info):
|
||||
return "******"
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
||||
def _genenv_ignoring_key_case(
|
||||
env_key: str, env_prefix: Optional[str] = None, default_value: Optional[str] = None
|
||||
):
|
||||
"""Get the value from the environment variable, ignoring the case of the key"""
|
||||
if env_prefix:
|
||||
env_key = env_prefix + env_key
|
||||
@@ -196,7 +198,9 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case_with_prefixes(
|
||||
env_key: str, env_prefixes: List[str] = None, default_value=None
|
||||
env_key: str,
|
||||
env_prefixes: Optional[List[str]] = None,
|
||||
default_value: Optional[str] = None,
|
||||
) -> str:
|
||||
if env_prefixes:
|
||||
for env_prefix in env_prefixes:
|
||||
@@ -208,7 +212,7 @@ def _genenv_ignoring_key_case_with_prefixes(
|
||||
|
||||
class EnvArgumentParser:
|
||||
@staticmethod
|
||||
def get_env_prefix(env_key: str) -> str:
|
||||
def get_env_prefix(env_key: str) -> Optional[str]:
|
||||
if not env_key:
|
||||
return None
|
||||
env_key = env_key.replace("-", "_")
|
||||
@@ -217,14 +221,14 @@ class EnvArgumentParser:
|
||||
def parse_args_into_dataclass(
|
||||
self,
|
||||
dataclass_type: Type,
|
||||
env_prefixes: List[str] = None,
|
||||
command_args: List[str] = None,
|
||||
env_prefixes: Optional[List[str]] = None,
|
||||
command_args: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for field in fields(dataclass_type):
|
||||
env_var_value = _genenv_ignoring_key_case_with_prefixes(
|
||||
env_var_value: Any = _genenv_ignoring_key_case_with_prefixes(
|
||||
field.name, env_prefixes
|
||||
)
|
||||
if env_var_value:
|
||||
@@ -313,7 +317,8 @@ class EnvArgumentParser:
|
||||
|
||||
@staticmethod
|
||||
def create_click_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
*dataclass_types: Type,
|
||||
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
|
||||
):
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
@@ -322,8 +327,9 @@ class EnvArgumentParser:
|
||||
if _dynamic_factory:
|
||||
_types = _dynamic_factory()
|
||||
if _types:
|
||||
dataclass_types = list(_types)
|
||||
dataclass_types = list(_types) # type: ignore
|
||||
for dataclass_type in dataclass_types:
|
||||
# type: ignore
|
||||
for field in fields(dataclass_type):
|
||||
if field.name not in combined_fields:
|
||||
combined_fields[field.name] = field
|
||||
@@ -345,7 +351,8 @@ class EnvArgumentParser:
|
||||
|
||||
@staticmethod
|
||||
def _create_raw_click_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
*dataclass_types: Type,
|
||||
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
|
||||
):
|
||||
combined_fields = _merge_dataclass_types(
|
||||
*dataclass_types, _dynamic_factory=_dynamic_factory
|
||||
@@ -362,7 +369,8 @@ class EnvArgumentParser:
|
||||
|
||||
@staticmethod
|
||||
def create_argparse_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
*dataclass_types: Type,
|
||||
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
|
||||
) -> argparse.ArgumentParser:
|
||||
combined_fields = _merge_dataclass_types(
|
||||
*dataclass_types, _dynamic_factory=_dynamic_factory
|
||||
@@ -429,7 +437,7 @@ class EnvArgumentParser:
|
||||
return "str"
|
||||
|
||||
@staticmethod
|
||||
def _is_require_type(field_type: Type) -> str:
|
||||
def _is_require_type(field_type: Type) -> bool:
|
||||
return field_type not in [Optional[int], Optional[float], Optional[bool]]
|
||||
|
||||
@staticmethod
|
||||
@@ -455,13 +463,13 @@ class EnvArgumentParser:
|
||||
|
||||
|
||||
def _merge_dataclass_types(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
*dataclass_types: Type, _dynamic_factory: Optional[Callable[[], List[Type]]] = None
|
||||
) -> OrderedDict:
|
||||
combined_fields = OrderedDict()
|
||||
if _dynamic_factory:
|
||||
_types = _dynamic_factory()
|
||||
if _types:
|
||||
dataclass_types = list(_types)
|
||||
dataclass_types = list(_types) # type: ignore
|
||||
for dataclass_type in dataclass_types:
|
||||
for field in fields(dataclass_type):
|
||||
if field.name not in combined_fields:
|
||||
@@ -511,11 +519,12 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
|
||||
if not desc:
|
||||
raise ValueError("Parameter descriptions cant be empty")
|
||||
param_class_str = desc[0].param_class
|
||||
class_name = None
|
||||
if param_class_str:
|
||||
param_class = import_from_string(param_class_str, ignore_import_error=True)
|
||||
if param_class:
|
||||
return param_class
|
||||
module_name, _, class_name = param_class_str.rpartition(".")
|
||||
module_name, _, class_name = param_class_str.rpartition(".")
|
||||
|
||||
fields_dict = {} # This will store field names and their default values or field()
|
||||
annotations = {} # This will store the type annotations for the fields
|
||||
@@ -526,25 +535,30 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
|
||||
metadata["valid_values"] = d.valid_values
|
||||
|
||||
annotations[d.param_name] = _type_str_to_python_type(
|
||||
d.param_type
|
||||
d.param_type # type: ignore
|
||||
) # Set type annotation
|
||||
fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
|
||||
|
||||
# Create the new class. Note the setting of __annotations__ for type hints
|
||||
new_class = type(
|
||||
class_name, (object,), {**fields_dict, "__annotations__": annotations}
|
||||
class_name, # type: ignore
|
||||
(object,),
|
||||
{**fields_dict, "__annotations__": annotations}, # type: ignore
|
||||
)
|
||||
result_class = dataclass(new_class) # Make it a dataclass
|
||||
# Make it a dataclass
|
||||
result_class = dataclass(new_class) # type: ignore
|
||||
|
||||
return result_class
|
||||
|
||||
|
||||
def _extract_parameter_details(
|
||||
parser: argparse.ArgumentParser,
|
||||
param_class: str = None,
|
||||
skip_names: List[str] = None,
|
||||
overwrite_default_values: Dict = {},
|
||||
param_class: Optional[str] = None,
|
||||
skip_names: Optional[List[str]] = None,
|
||||
overwrite_default_values: Optional[Dict[str, Any]] = None,
|
||||
) -> List[ParameterDescription]:
|
||||
if overwrite_default_values is None:
|
||||
overwrite_default_values = {}
|
||||
descriptions = []
|
||||
|
||||
for action in parser._actions:
|
||||
@@ -575,7 +589,9 @@ def _extract_parameter_details(
|
||||
if param_name in overwrite_default_values:
|
||||
default_value = overwrite_default_values[param_name]
|
||||
arg_type = (
|
||||
action.type if not callable(action.type) else str(action.type.__name__)
|
||||
action.type
|
||||
if not callable(action.type)
|
||||
else str(action.type.__name__) # type: ignore
|
||||
)
|
||||
description = action.help
|
||||
|
||||
@@ -583,10 +599,10 @@ def _extract_parameter_details(
|
||||
required = action.required
|
||||
|
||||
# extract valid values for choices, if provided
|
||||
valid_values = action.choices if action.choices is not None else None
|
||||
valid_values = list(action.choices) if action.choices is not None else None
|
||||
|
||||
# set ext_metadata as an empty dict for now, can be updated later if needed
|
||||
ext_metadata = {}
|
||||
ext_metadata: Dict[str, Any] = {}
|
||||
|
||||
descriptions.append(
|
||||
ParameterDescription(
|
||||
@@ -621,7 +637,7 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
||||
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
|
||||
from dbgpt._private import pydantic
|
||||
|
||||
version = int(pydantic.VERSION.split(".")[0])
|
||||
version = int(pydantic.VERSION.split(".")[0]) # type: ignore
|
||||
schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
|
||||
required_fields = set(schema.get("required", []))
|
||||
param_descs = []
|
||||
@@ -661,7 +677,7 @@ def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescri
|
||||
ext_metadata = (
|
||||
field.field_info.extra if hasattr(field.field_info, "extra") else None
|
||||
)
|
||||
param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
|
||||
param_class = f"{model_cls.__module__}.{model_cls.__name__}"
|
||||
param_desc = ParameterDescription(
|
||||
param_class=param_class,
|
||||
param_name=field_name,
|
||||
|
@@ -5,7 +5,7 @@ import asyncio
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
|
||||
@@ -28,19 +28,25 @@ def _get_logging_level() -> str:
|
||||
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
def setup_logging_level(logging_level=None, logger_name: str = None):
|
||||
def setup_logging_level(
|
||||
logging_level: Optional[str] = None, logger_name: Optional[str] = None
|
||||
):
|
||||
if not logging_level:
|
||||
logging_level = _get_logging_level()
|
||||
if type(logging_level) is str:
|
||||
logging_level = logging.getLevelName(logging_level.upper())
|
||||
if logger_name:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging_level)
|
||||
logger.setLevel(cast(str, logging_level))
|
||||
else:
|
||||
logging.basicConfig(level=logging_level, encoding="utf-8")
|
||||
|
||||
|
||||
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
|
||||
def setup_logging(
|
||||
logger_name: str,
|
||||
logging_level: Optional[str] = None,
|
||||
logger_filename: Optional[str] = None,
|
||||
):
|
||||
if not logging_level:
|
||||
logging_level = _get_logging_level()
|
||||
logger = _build_logger(logger_name, logging_level, logger_filename)
|
||||
@@ -74,7 +80,11 @@ def get_gpu_memory(max_gpus=None):
|
||||
return gpu_memory
|
||||
|
||||
|
||||
def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
|
||||
def _build_logger(
|
||||
logger_name,
|
||||
logging_level: Optional[str] = None,
|
||||
logger_filename: Optional[str] = None,
|
||||
):
|
||||
global handler
|
||||
|
||||
formatter = logging.Formatter(
|
||||
@@ -111,14 +121,14 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
assert loop is not None
|
||||
return loop
|
||||
return cast(asyncio.BaseEventLoop, loop)
|
||||
except RuntimeError as e:
|
||||
if not "no running event loop" in str(e) and not "no current event loop" in str(
|
||||
e
|
||||
):
|
||||
raise e
|
||||
logging.warning("Cant not get running event loop, create new event loop now")
|
||||
return asyncio.get_event_loop_policy().new_event_loop()
|
||||
return cast(asyncio.BaseEventLoop, asyncio.get_event_loop_policy().new_event_loop())
|
||||
|
||||
|
||||
def logging_str_to_uvicorn_level(log_level_str):
|
||||
@@ -152,7 +162,7 @@ class EndpointFilter(logging.Filter):
|
||||
return record.getMessage().find(self._path) == -1
|
||||
|
||||
|
||||
def setup_http_service_logging(exclude_paths: List[str] = None):
|
||||
def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
|
||||
"""Setup http service logging
|
||||
|
||||
Now just disable some logs
|
||||
|
Reference in New Issue
Block a user