Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline] refactor the pipeline module (#1087)
* [pipeline] refactor the pipeline module
* polish code
@@ -2,6 +2,5 @@ from .layer import *
 from .loss import *
 from .lr_scheduler import *
 from .metric import *
-from .model import *
 from .optimizer import *
 from ._ops import *
@@ -1,4 +1,3 @@
-from .lambda_wrapper import LambdaWrapper
 from .pipeline_wrapper import PipelineSharedModuleWrapper
 
-__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
+__all__ = ['PipelineSharedModuleWrapper']
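After this change the package only exports PipelineSharedModuleWrapper, the helper for sharing one module (typically the word embedding) between pipeline stages so that its gradients stay synchronized across the stages that own a copy. A minimal sketch of the usual pattern under an already-launched pipeline-parallel run; the import path, the constructor argument (the list of pipeline ranks sharing the module) and register_module are assumptions from common usage, not from this diff:

import torch.nn as nn
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper  # assumed import path

# assumes colossalai.launch(...) has already set up pipeline parallelism;
# stages 0 and 1 each build their own embedding and share its gradients
wrapper = PipelineSharedModuleWrapper([0, 1])

embedding = nn.Embedding(32000, 1024)
wrapper.register_module(embedding)  # assumed API: ties this stage's copy into the shared group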
@@ -1,36 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-
-from colossalai.builder import build_layer
-from colossalai.registry import LAYERS
-
-
-@LAYERS.register_module
-class LambdaWrapper(nn.Module):
-    """Wrap a function to nn.Module, which takes a config of layers and can fully access them.
-
-    Args:
-        func (``Callable``): User customed function.
-        layers_cfg (dict, optional): Config of layers, defaults to None.
-    """
-
-    def __init__(self, func, layers_cfg: dict = None):
-        super().__init__()
-        self.func = func
-        self.layers = self._build_layers(layers_cfg)
-
-    def _build_layers(self, layers_cfg: dict):
-        if layers_cfg is None:
-            return None
-        else:
-            layers = []
-
-            for cfg in layers_cfg:
-                layer = build_layer(cfg)
-                layers.append(layer)
-            return layers
-
-    def forward(self, *args, **kwargs):
-        return self.func(self, *args, **kwargs)
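The file deleted above is the LambdaWrapper helper, which turned an arbitrary callable into an nn.Module and could additionally instantiate a list of sub-layers from configs via build_layer. A minimal sketch of what it allowed before this commit, exercising only the func path (layers_cfg=None); the import path is an assumption, since the extracted diff does not show file names:

import torch
from colossalai.nn.layer.wrapper.lambda_wrapper import LambdaWrapper  # assumed (now deleted) import path

# wrap a plain function so it can sit wherever an nn.Module is expected;
# forward() calls func(self, *args), so the wrapper instance is passed as the first argument
double = LambdaWrapper(lambda module, x: x * 2)

x = torch.randn(4)
y = double(x)  # equivalent to x * 2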
@@ -1,3 +0,0 @@
-from .model_from_config import ModelFromConfig
-
-__all__ = ['ModelFromConfig']
@@ -1,37 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from abc import ABC, abstractmethod
-
-import torch.nn as nn
-
-from colossalai.builder import build_layer
-
-
-class ModelFromConfig(nn.Module, ABC):
-
-    def __init__(self):
-        super(ModelFromConfig, self).__init__()
-        self.layers = nn.ModuleList()
-        self.layers_cfg = []
-
-    def build_from_cfg(self, start=None, end=None):
-        assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
-                                            'spelling and if you have initialized this variable'
-        if start is None:
-            start = 0
-        if end is None:
-            end = len(self.layers_cfg)
-        for cfg in self.layers_cfg[start: end]:
-            layer = build_layer(cfg)
-            self.layers.append(layer)
-
-    @abstractmethod
-    def init_weights(self):
-        pass
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        """Use this function to override the state dict for
-        saving checkpoints."""
-        return self.state_dict(destination, prefix, keep_vars)
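ModelFromConfig, deleted above, was the abstract base class behind config-driven model definitions: a subclass filled self.layers_cfg with layer config dicts, and build_from_cfg(start, end) materialized a slice of them through build_layer, which is what config-based pipeline partitioning relied on. A minimal sketch of that subclassing pattern; the concrete layer configs (the 'type' key and that 'Linear' is registered in the LAYERS registry) are illustrative assumptions, not taken from this diff:

import torch.nn as nn
from colossalai.nn.model import ModelFromConfig  # removed by this commit

class TinyMLP(ModelFromConfig):
    def __init__(self):
        super().__init__()
        # each entry is handed to build_layer; 'type' is assumed to name a class
        # registered in the LAYERS registry (illustrative config only)
        self.layers_cfg = [
            dict(type='Linear', in_features=16, out_features=32),
            dict(type='Linear', in_features=32, out_features=4),
        ]
        self.build_from_cfg()  # builds each cfg and appends it to self.layers (an nn.ModuleList)

    def init_weights(self):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x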