mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-22 23:32:37 +00:00
update examples and sphnix docs for the new api (#63)
This commit is contained in:
@@ -16,8 +16,8 @@ def build_from_config(module, config: dict):
|
||||
of the return object
|
||||
:type config: dict
|
||||
:raises AssertionError: Raises an AssertionError if `module` is not a class
|
||||
:return: An object of :class:`module`
|
||||
:rtype: :class:`module`
|
||||
:return: An object of interest
|
||||
:rtype: Object
|
||||
"""
|
||||
assert inspect.isclass(module), 'module must be a class'
|
||||
return module(**config)
|
||||
@@ -62,8 +62,8 @@ def build_layer(config):
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object
|
||||
containing information used in the construction of the return object
|
||||
:type config: dict or :class:`colossalai.context.Config`
|
||||
:return: An object of :class:`nn.Module`
|
||||
:rtype: :class:`nn.Module`
|
||||
:return: An object of :class:`torch.nn.Module`
|
||||
:rtype: :class:`torch.nn.Module`
|
||||
"""
|
||||
return build_from_registry(config, LAYERS)
|
||||
|
||||
@@ -75,8 +75,8 @@ def build_loss(config):
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object
|
||||
containing information used in the construction of the return object
|
||||
:type config: dict or :class:`colossalai.context.Config`
|
||||
:return: An object of :class:`torch.autograd.Function`
|
||||
:rtype: :class:`torch.autograd.Function`
|
||||
:return: An object of :class:`torch.nn.modules.loss._Loss`
|
||||
:rtype: :class:`torch.nn.modules.loss._Loss`
|
||||
"""
|
||||
return build_from_registry(config, LOSSES)
|
||||
|
||||
@@ -87,8 +87,8 @@ def build_model(config):
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object
|
||||
containing information used in the construction of the return object
|
||||
:type config: dict or :class:`colossalai.context.Config`
|
||||
:return: An object of :class:`nn.Module`
|
||||
:rtype: :class:`nn.Module`
|
||||
:return: An object of :class:`torch.nn.Module`
|
||||
:rtype: :class:`torch.nn.Module`
|
||||
"""
|
||||
return build_from_registry(config, MODELS)
|
||||
|
||||
@@ -134,8 +134,8 @@ def build_gradient_handler(config, model, optimizer):
|
||||
:type model: :class:`nn.Module`
|
||||
:param optimizer: An optimizer object containing parameters for the gradient handler
|
||||
:type optimizer: :class:`torch.optim.Optimizer`
|
||||
:return: An object of :class:`BaseGradientHandler`
|
||||
:rtype: :class:`BaseGradientHandler`
|
||||
:return: An object of :class:`colossalai.engine.BaseGradientHandler`
|
||||
:rtype: :class:`colossalai.engine.BaseGradientHandler`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['model'] = model
|
||||
@@ -151,8 +151,8 @@ def build_hooks(config, trainer):
|
||||
:type config: dict or :class:`colossalai.context.Config`
|
||||
:param trainer: A :class:`Trainer` object containing parameters for the hook
|
||||
:type trainer: :class:`Trainer`
|
||||
:return: An object of :class:`BaseHook`
|
||||
:rtype: :class:`BaseHook`
|
||||
:return: An object of :class:`colossalai.trainer.hooks.BaseHook`
|
||||
:rtype: :class:`colossalai.trainer.hooks.BaseHook`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['trainer'] = trainer
|
||||
@@ -182,8 +182,8 @@ def build_data_sampler(config, dataset):
|
||||
:param dataset: An object of :class:`torch.utils.data.Dataset` containing information
|
||||
used in the construction of the return object
|
||||
:type dataset: :class:`torch.utils.data.Dataset`
|
||||
:return: An object of :class:`colossalai.nn.data.sampler.BaseSampler`
|
||||
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
|
||||
:return: An object of :class:`colossalai.utils.data_sampler.BaseSampler`
|
||||
:rtype: :class:`colossalai.utils.data_sampler.BaseSampler`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['dataset'] = dataset
|
||||
@@ -200,10 +200,6 @@ def build_lr_scheduler(config, optimizer):
|
||||
:param optimizer: An optimizer object containing parameters for the learning rate
|
||||
scheduler
|
||||
:type optimizer: :class:`torch.optim.Optimizer`
|
||||
:param total_steps: Number of total steps of the learning rate scheduler
|
||||
:type total_steps: int
|
||||
:param num_steps_per_epoch: number of steps per epoch of the learning rate scheduler
|
||||
:type num_steps_per_epoch: int
|
||||
:return: An object of :class:`torch.optim.lr_scheduler`
|
||||
:rtype: :class:`torch.optim.lr_scheduler`
|
||||
"""
|
||||
|
Reference in New Issue
Block a user