mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .builder import build_from_config, build_from_registry, build_gradient_handler
|
||||
|
||||
__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry']
|
||||
__all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"]
|
||||
|
@@ -19,7 +19,7 @@ def build_from_config(module, config: dict):
|
||||
AssertionError: Raises an AssertionError if `module` is not a class
|
||||
|
||||
"""
|
||||
assert inspect.isclass(module), 'module must be a class'
|
||||
assert inspect.isclass(module), "module must be a class"
|
||||
return module(**config)
|
||||
|
||||
|
||||
@@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry):
|
||||
Raises:
|
||||
Exception: Raises an Exception if an error occurred when building from registry.
|
||||
"""
|
||||
config_ = config.copy() # keep the original config untouched
|
||||
assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}'
|
||||
config_ = config.copy() # keep the original config untouched
|
||||
assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}"
|
||||
|
||||
mod_type = config_.pop('type')
|
||||
assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}'
|
||||
mod_type = config_.pop("type")
|
||||
assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}"
|
||||
try:
|
||||
obj = registry.get_module(mod_type)(**config_)
|
||||
except Exception as e:
|
||||
print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
|
||||
print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True)
|
||||
raise e
|
||||
|
||||
return obj
|
||||
@@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer):
|
||||
An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['model'] = model
|
||||
config_['optimizer'] = optimizer
|
||||
config_["model"] = model
|
||||
config_["optimizer"] = optimizer
|
||||
return build_from_registry(config_, GRADIENT_HANDLER)
|
||||
|
Reference in New Issue
Block a user