Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[legacy] move builder and registry to legacy (#4603)
colossalai/legacy/builder/__init__.py (new file, 3 lines added)
@@ -0,0 +1,3 @@
from .builder import build_from_config, build_from_registry, build_gradient_handler

__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry']
colossalai/legacy/builder/builder.py (new file, 79 lines added)
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import inspect

from colossalai.legacy.registry import *


def build_from_config(module, config: dict):
    """Returns an object of :class:`module` constructed from `config`.

    Args:
        module: A Python or user-defined class
        config: A Python dict containing information used in the construction of the returned object

    Returns: An ``object`` of interest

    Raises:
        AssertionError: Raises an AssertionError if `module` is not a class

    """
    assert inspect.isclass(module), 'module must be a class'
    return module(**config)


def build_from_registry(config, registry: Registry):
    r"""Returns an object constructed from `config`; the type of the object
    is specified by `registry`.

    Note:
        The `config` is used to construct the return object such as `LAYERS`, `OPTIMIZERS`
        and other supported types in `registry`. The `config` should contain
        all required parameters of the corresponding object. The details of the supported
        types in `registry` and the `mod_type` in `config` can be found in
        `registry <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/registry/__init__.py>`_.

    Args:
        config (dict or :class:`colossalai.context.Config`): information
            used in the construction of the return object.
        registry (:class:`Registry`): A registry specifying the type of the return object

    Returns:
        A Python object specified by `registry`.

    Raises:
        Exception: Raises an Exception if an error occurred when building from the registry.
    """
    config_ = config.copy()    # keep the original config untouched
    assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}'

    mod_type = config_.pop('type')
    assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}'
    try:
        obj = registry.get_module(mod_type)(**config_)
    except Exception as e:
        print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
        raise e

    return obj


def build_gradient_handler(config, model, optimizer):
    """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
    `model` and `optimizer`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A Python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``GRADIENT_HANDLER``.
        model (:class:`nn.Module`): A model containing parameters for the gradient handler
        optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler

    Returns:
        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
    """
    config_ = config.copy()
    config_['model'] = model
    config_['optimizer'] = optimizer
    return build_from_registry(config_, GRADIENT_HANDLER)
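A minimal sketch of how these builder helpers fit together. MyGradientHandler and the config values below are hypothetical and exist only to illustrate the flow; the registry and builder functions are the ones added in this commit.

from colossalai.legacy.builder import build_from_config, build_from_registry, build_gradient_handler
from colossalai.legacy.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class MyGradientHandler:
    # hypothetical handler, used only for illustration
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer


# build_from_config instantiates a class directly from keyword arguments
handler = build_from_config(MyGradientHandler, {'model': None, 'optimizer': None})

# build_from_registry pops 'type', looks the class up in the registry and
# forwards the remaining keys to the constructor
handler = build_from_registry({'type': 'MyGradientHandler', 'model': None, 'optimizer': None}, GRADIENT_HANDLER)

# build_gradient_handler injects model and optimizer into the config before delegating
handler = build_gradient_handler({'type': 'MyGradientHandler'}, model=None, optimizer=None)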
@@ -1,6 +1,6 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
@@ -1,7 +1,7 @@
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import GRADIENT_HANDLER
|
||||
from colossalai.legacy.registry import GRADIENT_HANDLER
|
||||
from colossalai.utils.moe import get_moe_epsize_param_dict
|
||||
|
||||
from ._base_gradient_handler import BaseGradientHandler
|
||||
|
@@ -7,7 +7,7 @@ import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler

@@ -1,6 +1,6 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler

colossalai/legacy/registry/__init__.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
import torch.distributed.optim as dist_optim
import torch.nn as nn
import torch.optim as optim

from .registry import Registry

LAYERS = Registry("layers", third_party_library=[nn])
MODELS = Registry("models")
OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim])
DATASETS = Registry("datasets")
DIST_GROUP_INITIALIZER = Registry("dist_group_initializer")
GRADIENT_HANDLER = Registry("gradient_handler")
LOSSES = Registry("losses", third_party_library=[nn])
HOOKS = Registry("hooks")
TRANSFORMS = Registry("transforms")
DATA_SAMPLERS = Registry("data_samplers")
LR_SCHEDULERS = Registry("lr_schedulers")
SCHEDULE = Registry("schedules")
OPHOOKS = Registry("ophooks")
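The third_party_library argument gives these registries a fallback lookup path. A minimal sketch, assuming the lookup behaviour of Registry.has and Registry.get_module shown in registry.py below: 'SGD' is never registered explicitly, yet OPTIMIZERS resolves it through torch.optim.

import torch.nn as nn

from colossalai.legacy.registry import OPTIMIZERS

assert OPTIMIZERS.has('SGD')                  # found through the torch.optim fallback
optimizer_cls = OPTIMIZERS.get_module('SGD')  # torch.optim.SGD
optimizer = optimizer_cls(nn.Linear(4, 4).parameters(), lr=0.1)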
colossalai/legacy/registry/registry.py (new file, 82 lines added)
@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from types import ModuleType
from typing import List


class Registry:
    """This is a registry class used to register classes and modules so that a universal
    object builder can be enabled.

    Args:
        name (str): The name of the registry.
        third_party_library (list, optional):
            List of third party libraries which are used in the initialization of the registered module.
    """

    def __init__(self, name: str, third_party_library: List[ModuleType] = None):
        self._name = name
        self._registry = dict()
        self._third_party_lib = third_party_library

    @property
    def name(self):
        return self._name

    def register_module(self, module_class):
        """Registers a module represented in `module_class`.

        Args:
            module_class (class): The module to be registered.
        Returns:
            class: The registered module, so that it can still be used normally via importing.
        Raises:
            AssertionError: Raises an AssertionError if the module has already been registered before.
        """
        module_name = module_class.__name__
        assert module_name not in self._registry, f"{module_name} is already registered in {self.name}"
        self._registry[module_name] = module_class

        # return the class so that it can still be used normally via importing
        return module_class

    def get_module(self, module_name: str):
        """Retrieves a module with name `module_name` and returns the module if it has
        already been registered before.

        Args:
            module_name (str): The name of the module to be retrieved.
        Returns:
            :class:`object`: The retrieved module or None.
        Raises:
            NameError: Raises a NameError if the module to be retrieved has neither been
                registered directly nor as a third party module before.
        """
        if module_name in self._registry:
            return self._registry[module_name]
        elif self._third_party_lib is not None:
            for lib in self._third_party_lib:
                if hasattr(lib, module_name):
                    return getattr(lib, module_name)
        raise NameError(f'Module {module_name} not found in the registry {self.name}')

    def has(self, module_name: str):
        """Searches for a module with name `module_name` and returns a boolean value indicating
        whether the module has been registered directly or as a third party module before.

        Args:
            module_name (str): The name of the module to be searched for.
        Returns:
            bool: A boolean value indicating whether the module has been registered directly or
                as a third party module before.
        """
        found_flag = module_name in self._registry

        if self._third_party_lib:
            for lib in self._third_party_lib:
                if hasattr(lib, module_name):
                    found_flag = True
                    break

        return found_flag
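A minimal round trip through the Registry API; the 'demo' registry and DummyModel are hypothetical and serve only to illustrate register_module, has and get_module.

from colossalai.legacy.registry.registry import Registry

DEMO = Registry('demo')


@DEMO.register_module
class DummyModel:
    # hypothetical class, used only for illustration
    pass


assert DEMO.has('DummyModel')
assert DEMO.get_module('DummyModel') is DummyModel

try:
    DEMO.get_module('Missing')
except NameError as err:
    print(err)    # Module Missing not found in the registry demo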
@@ -2,9 +2,9 @@
 # -*- encoding: utf-8 -*-
 import torch

+from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
 from colossalai.utils.checkpointing import save_checkpoint

 from ._lr_scheduler_hook import LRSchedulerHook
@@ -7,9 +7,9 @@ from typing import List

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.registry import HOOKS
 from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage

 from ._base_hook import BaseHook
@@ -1,6 +1,6 @@
 from torch import Tensor

-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS

 from ._metric_hook import LearningRateMetric, MetricHook

@@ -10,7 +10,7 @@ import torch.distributed as dist
 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage

 from ._base_hook import BaseHook
@@ -356,7 +356,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
-                gpc.get_world_size(ParallelMode.DATA)
+                gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)

             sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -367,7 +367,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
-                gpc.get_world_size(ParallelMode.DATA)
+                gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)

             sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
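The two hunks above only touch the continuation line of the all_reduce expression; the throughput formula itself is unchanged: per-rank sample counts are summed across the data-parallel group while the step time is summed and then divided back by the world size. A back-of-the-envelope sketch, assuming 4 data-parallel ranks that each processed 32 samples in about 0.5 s:

world_size = 4                     # gpc.get_world_size(ParallelMode.DATA)
per_rank_samples = 32
per_rank_time = 0.5                # seconds

last_step_num_samples = per_rank_samples * world_size              # all_reduce sums to 128
last_step_used_time = (per_rank_time * world_size) / world_size    # summed, then averaged back to ~0.5

sample_per_sec = last_step_num_samples / (last_step_used_time + 1e-12)
print(f'{sample_per_sec:.1f} samples/s')    # roughly 256.0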