[pipeline] refactor the pipeline module (#1087)

* [pipeline] refactor the pipeline module

* polish code
This commit is contained in:
Frank Lee
2022-06-10 11:27:38 +08:00
committed by GitHub
parent bad5d4c0a1
commit 2b2dc1c86b
29 changed files with 366 additions and 1127 deletions

View File

@@ -2,6 +2,5 @@ from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
from ._ops import *

View File

@@ -1,4 +1,3 @@
from .lambda_wrapper import LambdaWrapper
from .pipeline_wrapper import PipelineSharedModuleWrapper
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
__all__ = ['PipelineSharedModuleWrapper']

View File

@@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
    """Wrap a callable into an ``nn.Module`` that can build and access a set of layers.

    Args:
        func (Callable): User-defined function; invoked as ``func(self, *args, **kwargs)``
            so it has full access to ``self.layers``.
        layers_cfg (dict, optional): Config of layers passed to ``build_layer``.
            Defaults to None, in which case ``self.layers`` is None.
    """

    def __init__(self, func, layers_cfg: dict = None):
        super().__init__()
        self.func = func
        self.layers = self._build_layers(layers_cfg)

    def _build_layers(self, layers_cfg: dict):
        # Keep None when no config is given so callers can distinguish
        # "no layers" from "empty layer list".
        if layers_cfg is None:
            return None
        # Use nn.ModuleList (not a plain Python list) so the built layers are
        # registered as submodules: their parameters then appear in
        # `parameters()`/`state_dict()` and follow `.to(device)` / `.cuda()`.
        return nn.ModuleList(build_layer(cfg) for cfg in layers_cfg)

    def forward(self, *args, **kwargs):
        # Delegate to the wrapped function, passing the module itself so the
        # function can reach `self.layers`.
        return self.func(self, *args, **kwargs)

View File

@@ -1,3 +0,0 @@
from .model_from_config import ModelFromConfig
__all__ = ['ModelFromConfig']

View File

@@ -1,37 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.builder import build_layer
class ModelFromConfig(nn.Module, ABC):
    """Base class for models assembled from a list of layer configs.

    Subclasses populate ``self.layers_cfg`` with layer config dicts and call
    :meth:`build_from_cfg` to instantiate them into ``self.layers``.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        # Layer config dicts; subclasses fill this before calling build_from_cfg.
        self.layers_cfg = []

    def build_from_cfg(self, start: int = None, end: int = None):
        """Build layers from ``self.layers_cfg[start:end]`` and append them to ``self.layers``.

        Args:
            start (int, optional): Index of the first config to build. Defaults to 0.
            end (int, optional): One past the last config to build.
                Defaults to ``len(self.layers_cfg)``.
        """
        assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
                                            'spelling and if you have initialized this variable'
        if start is None:
            start = 0
        if end is None:
            end = len(self.layers_cfg)
        for cfg in self.layers_cfg[start:end]:
            self.layers.append(build_layer(cfg))

    @abstractmethod
    def init_weights(self):
        """Initialize model weights; must be implemented by subclasses."""
        pass

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict used when saving checkpoints.

        Override in subclasses to customize what gets checkpointed.
        """
        # Pass arguments by keyword: recent torch versions deprecate calling
        # `Module.state_dict` with positional arguments.
        return self.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)