[legacy] move builder and registry to legacy (#4603)

Hongxin Liu
2023-09-04 19:56:42 +08:00
parent 8accecd55b
commit ac178ca5c1
65 changed files with 353 additions and 332 deletions

View File

@@ -0,0 +1,3 @@
from .builder import build_from_config, build_from_registry, build_gradient_handler
__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry']
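With the builder moved under the legacy package, downstream code imports these helpers from the new namespace. A minimal sketch; file paths are stripped from this view, so the exact module path is an assumption based on the commit title:

# Path assumed to be colossalai.legacy.builder (file names are not shown in this view).
from colossalai.legacy.builder import build_from_config, build_from_registry, build_gradient_handler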

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import inspect
from colossalai.legacy.registry import *
def build_from_config(module, config: dict):
    """Returns an object of :class:`module` constructed from `config`.

    Args:
        module: A Python class.
        config: A Python dict containing information used to construct the object.

    Returns:
        An ``object`` of interest.

    Raises:
        AssertionError: Raises an AssertionError if `module` is not a class.
    """
    assert inspect.isclass(module), 'module must be a class'
    return module(**config)

def build_from_registry(config, registry: Registry):
    r"""Returns an object constructed from `config`, the type of the object
    is specified by `registry`.

    Note:
        The `config` is used to construct the return object, whose type is
        specified by `registry` (e.g. `LAYERS`, `OPTIMIZERS` and the other
        supported registries). The `config` should contain all required
        parameters of the corresponding object. The supported types in
        `registry` and the `mod_type` in `config` can be found in
        `registry <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/legacy/registry/__init__.py>`_.

    Args:
        config (dict or :class:`colossalai.context.Config`): Information used
            in the construction of the return object.
        registry (:class:`Registry`): A registry specifying the type of the return object.

    Returns:
        A Python object specified by `registry`.

    Raises:
        Exception: Raises an Exception if an error occurred when building from registry.
    """
    config_ = config.copy()    # keep the original config untouched
    assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}'

    mod_type = config_.pop('type')
    assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}'
    try:
        obj = registry.get_module(mod_type)(**config_)
    except Exception as e:
        print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
        raise e

    return obj

def build_gradient_handler(config, model, optimizer):
    """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
    `model` and `optimizer`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A Python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``GRADIENT_HANDLER``.
        model (:class:`nn.Module`): A model containing parameters for the gradient handler.
        optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler.

    Returns:
        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`.
    """
    config_ = config.copy()
    config_['model'] = model
    config_['optimizer'] = optimizer
    return build_from_registry(config_, GRADIENT_HANDLER)
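
Taken together, the entry points above support building either directly from a class or by name through a registry. A minimal, self-contained sketch, assuming this file lives at `colossalai.legacy.builder` (the view strips file paths); the toy `ToyModel` class and `TOY_MODELS` registry are hypothetical:

from colossalai.legacy.builder import build_from_config, build_from_registry    # path assumed
from colossalai.legacy.registry import Registry

TOY_MODELS = Registry("toy_models")    # hypothetical registry for illustration

@TOY_MODELS.register_module
class ToyModel:
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size

# Direct construction: class plus kwargs dict.
model_a = build_from_config(ToyModel, {"hidden_size": 128})

# Registry-based construction: 'type' selects the registered class,
# the remaining keys become constructor kwargs.
model_b = build_from_registry({"type": "ToyModel", "hidden_size": 128}, TOY_MODELS)
assert model_a.hidden_size == model_b.hidden_size == 128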

View File

@@ -1,6 +1,6 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce

View File

@@ -1,7 +1,7 @@
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
from ._base_gradient_handler import BaseGradientHandler

View File

@@ -7,7 +7,7 @@ import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler

View File

@@ -1,6 +1,6 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce

View File

@@ -1,4 +1,4 @@
from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
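
Each handler in these hunks registers itself against the relocated GRADIENT_HANDLER registry. A hedged sketch of a user-defined handler following the same pattern; the `handle_gradient` hook name comes from the handlers in this package, and the no-op body is purely illustrative:

from colossalai.legacy.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler    # relative import, valid inside the package

@GRADIENT_HANDLER.register_module
class NoSyncGradientHandler(BaseGradientHandler):
    """Illustrative handler that deliberately skips gradient synchronization."""

    def handle_gradient(self):
        # Real handlers all-reduce gradients here; this one is a no-op.
        pass

Because `build_gradient_handler` injects `model` and `optimizer` into the config, `build_gradient_handler({'type': 'NoSyncGradientHandler'}, model, optimizer)` would then construct it.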

View File

@@ -0,0 +1,19 @@
import torch.distributed.optim as dist_optim
import torch.nn as nn
import torch.optim as optim
from .registry import Registry
LAYERS = Registry("layers", third_party_library=[nn])
MODELS = Registry("models")
OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim])
DATASETS = Registry("datasets")
DIST_GROUP_INITIALIZER = Registry("dist_group_initializer")
GRADIENT_HANDLER = Registry("gradient_handler")
LOSSES = Registry("losses", third_party_library=[nn])
HOOKS = Registry("hooks")
TRANSFORMS = Registry("transforms")
DATA_SAMPLERS = Registry("data_samplers")
LR_SCHEDULERS = Registry("lr_schedulers")
SCHEDULE = Registry("schedules")
OPHOOKS = Registry("ophooks")
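
The `third_party_library` hook means names that were never registered directly can still resolve against the listed modules. A small check, assuming the package exports shown here; torch's `SGD` is found on `torch.optim` rather than in the registry dict:

import torch.optim as optim

from colossalai.legacy.registry import OPTIMIZERS

assert OPTIMIZERS.has("SGD")
assert OPTIMIZERS.get_module("SGD") is optim.SGD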

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from types import ModuleType
from typing import List
class Registry:
    """This is a registry class used to register classes and modules so that a universal
    object builder can be enabled.

    Args:
        name (str): The name of the registry.
        third_party_library (list, optional): List of third-party libraries which
            are used in the initialization of the registered module.
    """

    def __init__(self, name: str, third_party_library: List[ModuleType] = None):
        self._name = name
        self._registry = dict()
        self._third_party_lib = third_party_library

    @property
    def name(self):
        return self._name

    def register_module(self, module_class):
        """Registers a module represented in `module_class`.

        Args:
            module_class (class): The module to be registered.

        Returns:
            class: The registered module, returned unchanged so that it can still
            be used normally when imported.

        Raises:
            AssertionError: Raises an AssertionError if the module has already been registered before.
        """
        module_name = module_class.__name__
        assert module_name not in self._registry, f"{module_name} has already been registered in {self.name}"
        self._registry[module_name] = module_class

        # return the class itself so the decorated class can still be used normally
        return module_class

    def get_module(self, module_name: str):
        """Retrieves a module with name `module_name` and returns the module if it has
        already been registered before.

        Args:
            module_name (str): The name of the module to be retrieved.

        Returns:
            :class:`object`: The retrieved module.

        Raises:
            NameError: Raises a NameError if the module to be retrieved has neither been
                registered directly nor found in one of the third-party libraries.
        """
        if module_name in self._registry:
            return self._registry[module_name]
        elif self._third_party_lib is not None:
            for lib in self._third_party_lib:
                if hasattr(lib, module_name):
                    return getattr(lib, module_name)
        raise NameError(f'Module {module_name} not found in the registry {self.name}')

    def has(self, module_name: str):
        """Searches for a module with name `module_name` and returns a boolean value indicating
        whether the module has been registered directly or is provided by one of the
        third-party libraries.

        Args:
            module_name (str): The name of the module to be searched for.

        Returns:
            bool: A boolean value indicating whether the module has been registered
            directly or is provided by one of the third-party libraries.
        """
        found_flag = module_name in self._registry

        if self._third_party_lib:
            for lib in self._third_party_lib:
                if hasattr(lib, module_name):
                    found_flag = True
                    break

        return found_flag
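
Putting the pieces together: `register_module` works as a plain class decorator, `get_module`/`has` look up by class name, and double registration trips the assertion. A hedged sketch using only the API defined above; the toy registry and class are hypothetical:

from colossalai.legacy.registry import Registry

ACTIVATIONS = Registry("activations")    # hypothetical registry

@ACTIVATIONS.register_module
class GELUSquared:
    pass

assert ACTIVATIONS.has("GELUSquared")
assert ACTIVATIONS.get_module("GELUSquared") is GELUSquared

try:
    ACTIVATIONS.register_module(GELUSquared)    # same class name twice
except AssertionError as err:
    print(err)    # GELUSquared has already been registered in activations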

View File

@@ -2,9 +2,9 @@
# -*- encoding: utf-8 -*-
import torch
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
from colossalai.utils.checkpointing import save_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook

View File

@@ -7,9 +7,9 @@ from typing import List
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
from colossalai.registry import HOOKS
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
from ._base_hook import BaseHook

View File

@@ -1,6 +1,6 @@
from torch import Tensor
from colossalai.registry import HOOKS
from colossalai.legacy.registry import HOOKS
from ._metric_hook import LearningRateMetric, MetricHook
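
The trainer hooks register against the relocated HOOKS registry in the same way. A hedged sketch of a custom hook; the `BaseHook` priority constructor and the `after_train_iter` callback signature are assumptions about this codebase, not confirmed by the diff:

from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks import BaseHook

@HOOKS.register_module
class PrintLossHook(BaseHook):
    """Illustrative hook; constructor and callback signatures are assumed."""

    def __init__(self, priority: int = 10):
        super().__init__(priority)

    def after_train_iter(self, trainer, output, label, loss):
        print(f"loss: {loss.item():.4f}")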

View File

@@ -10,7 +10,7 @@ import torch.distributed as dist
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.legacy.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
@@ -356,7 +356,7 @@ class ThroughputMetric(Metric):
            self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
        else:
            self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                gpc.get_world_size(ParallelMode.DATA)
            self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -367,7 +367,7 @@ class ThroughputMetric(Metric):
            self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
        else:
            self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                gpc.get_world_size(ParallelMode.DATA)
            self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
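
In the non-local branch above, the per-rank used time is all-reduced and divided by the data-parallel world size (an average), while the sample count is all-reduced without division (a sum), so the reported rate is global samples over mean seconds. A plain-Python numeric check of that arithmetic, no distributed setup required:

# Two data-parallel ranks, each processing 64 samples in roughly 2 seconds.
world_size = 2
per_rank_samples = [64, 64]
per_rank_time = [2.0, 2.2]

global_samples = sum(per_rank_samples)                  # all_reduce(sum) on samples
mean_time = sum(per_rank_time) / world_size             # all_reduce(sum) / world size on time

sample_per_sec = global_samples / (mean_time + 1e-12)   # matches the formula above
print(f"{sample_per_sec:.2f} samples/sec")              # ~60.95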