mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[pipeline] refactor the pipeline module (#1087)
* [pipeline] refactor the pipeline module * polish code
This commit is contained in:
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from colossalai.context.config import Config
|
||||
from colossalai.builder import build_ophooks
|
||||
|
||||
|
||||
@pytest.mark.cpu
|
||||
|
@@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.utils.model.pipelinable import PipelinableContext
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from tqdm import tqdm
|
||||
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
from titans.model.vit import vit_tiny_patch4_32
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.transforms import transforms
|
||||
try:
|
||||
from titans.model.vit import vit_tiny_patch4_32
|
||||
except:
|
||||
pass
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_EPOCHS = 60
|
||||
@@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port):
|
||||
|
||||
# craete dataloaders
|
||||
root = Path(os.environ['DATA'])
|
||||
train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, pad_if_needed=True, crop=32, resize=32)
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(224, padding=4, pad_if_needed=True),
|
||||
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
|
||||
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.utils.model.pipelinable import PipelinableContext
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
|
||||
from colossalai.testing import rerun_on_exception
|
||||
|
||||
@@ -33,7 +33,7 @@ def run_pipelinable(rank):
|
||||
model = MLP()
|
||||
|
||||
assert pipelinable.policy == "balanced"
|
||||
pipelinable.load_policy("uniform")
|
||||
pipelinable.policy = "uniform"
|
||||
assert pipelinable.policy == "uniform"
|
||||
pipelinable.to_layer_list()
|
||||
|
@@ -1,2 +0,0 @@
|
||||
from .layers import *
|
||||
from .resnet import VanillaResNet
|
@@ -1,3 +0,0 @@
|
||||
from .basic_block import ResNetBasicBlock
|
||||
from .bottleneck import ResNetBottleneck
|
||||
from .reslayer import ResLayer
|
@@ -1,64 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv3x3
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResNetBasicBlock(nn.Module):
|
||||
"""Basic ResNet block
|
||||
"""
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
@@ -1,69 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv3x3, conv1x1
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResNetBottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
@@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv1x1
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
block_type: str,
|
||||
norm_layer_type: str,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
blocks: int,
|
||||
groups: int,
|
||||
base_width: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
dilate: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.block = LAYERS.get_module(block_type)
|
||||
self.norm_layer = LAYERS.get_module(norm_layer_type)
|
||||
self.inplanes = inplanes
|
||||
self.planes = planes
|
||||
self.blocks = blocks
|
||||
self.groups = groups
|
||||
self.dilation = dilation
|
||||
self.base_width = base_width
|
||||
self.dilate = dilate
|
||||
self.stride = stride
|
||||
self.layer = self._make_layer()
|
||||
|
||||
def _make_layer(self):
|
||||
norm_layer = self.norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if self.dilate:
|
||||
self.dilation *= self.stride
|
||||
self.stride = 1
|
||||
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
|
||||
norm_layer(self.planes * self.block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = self.planes * self.block.expansion
|
||||
for _ in range(1, self.blocks):
|
||||
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
@@ -1,163 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.registry import MODELS
|
||||
from colossalai.nn.model import ModelFromConfig
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VanillaResNet(ModelFromConfig):
|
||||
"""ResNet from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_cls: int,
|
||||
block_type: str,
|
||||
layers: List[int],
|
||||
norm_layer_type: str = 'BatchNorm2d',
|
||||
in_channels: int = 3,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
zero_init_residual: bool = False,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
dilations=(1, 1, 1, 1)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.inplanes = 64
|
||||
self.zero_init_residual = zero_init_residual
|
||||
self.blocks = layers
|
||||
self.block_expansion = LAYERS.get_module(block_type).expansion
|
||||
self.dilations = dilations
|
||||
self.reslayer_common_cfg = dict(
|
||||
type='ResLayer',
|
||||
block_type=block_type,
|
||||
norm_layer_type=norm_layer_type,
|
||||
groups=groups,
|
||||
base_width=width_per_group
|
||||
)
|
||||
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
|
||||
self.layers_cfg = [
|
||||
# conv1
|
||||
dict(type='Conv2d',
|
||||
in_channels=in_channels,
|
||||
out_channels=self.inplanes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False),
|
||||
# bn1
|
||||
dict(
|
||||
type=norm_layer_type,
|
||||
num_features=self.inplanes
|
||||
),
|
||||
# relu
|
||||
dict(
|
||||
type='ReLU',
|
||||
inplace=True
|
||||
),
|
||||
# maxpool
|
||||
dict(
|
||||
type='MaxPool2d',
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1
|
||||
),
|
||||
# layer 1
|
||||
dict(
|
||||
inplanes=self.inplanes,
|
||||
planes=64,
|
||||
blocks=self.blocks[0],
|
||||
dilation=self.dilations[0],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 2
|
||||
dict(
|
||||
inplanes=64 * self.block_expansion,
|
||||
planes=128,
|
||||
blocks=self.blocks[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0],
|
||||
dilation=self.dilations[1],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 3
|
||||
dict(
|
||||
inplanes=128 * self.block_expansion,
|
||||
planes=256,
|
||||
blocks=layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1],
|
||||
dilation=self.dilations[2],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 4
|
||||
dict(
|
||||
inplanes=256 * self.block_expansion,
|
||||
planes=512,
|
||||
blocks=layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2],
|
||||
dilation=self.dilations[3],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# avg pool
|
||||
dict(
|
||||
type='AdaptiveAvgPool2d',
|
||||
output_size=(1, 1)
|
||||
),
|
||||
# flatten
|
||||
dict(
|
||||
type='LambdaWrapper',
|
||||
func=lambda mod, x: torch.flatten(x, 1)
|
||||
),
|
||||
# linear
|
||||
dict(
|
||||
type='Linear',
|
||||
in_features=512 * self.block_expansion,
|
||||
out_features=num_cls
|
||||
)
|
||||
]
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if self.zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
@@ -1,21 +0,0 @@
|
||||
import os
|
||||
import model
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE = 128
|
||||
IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
NUM_MICRO_BATCHES = 2
|
||||
|
||||
# resnet 18
|
||||
model = dict(type='VanillaResNet',
|
||||
block_type='ResNetBasicBlock',
|
||||
layers=[2, 2, 2, 2],
|
||||
num_cls=10)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=4),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
@@ -1,43 +0,0 @@
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
|
||||
from colossalai.core import global_context
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import get_dist_logger
|
||||
from functools import partial
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_on_exception
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')
|
||||
|
||||
|
||||
def run_partition(rank, world_size, port):
|
||||
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
logger.info('finished initialization')
|
||||
|
||||
# build model
|
||||
model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
logger.info('model is created')
|
||||
|
||||
global_context.destroy()
|
||||
logger.info('training finished')
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_partition():
|
||||
world_size = 4
|
||||
run_func = partial(run_partition, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_partition()
|
@@ -8,27 +8,45 @@ from pathlib import Path
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.builder import build_pipeline_model_from_cfg
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.schedule import PipelineSchedule
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_dataloader, print_rank_0
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet18
|
||||
|
||||
BATCH_SIZE = 4
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
|
||||
BATCH_SIZE = 8
|
||||
|
||||
CONFIG=dict(
|
||||
NUM_MICRO_BATCHES=2,
|
||||
parallel = dict(
|
||||
pipeline=dict(size=2),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
)
|
||||
|
||||
def run_schedule(rank, world_size, port):
|
||||
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = build_pipeline_model_from_cfg(gpc.config.model, 1)
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
|
||||
model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
|
||||
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
|
||||
|
||||
class Flatten(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return torch.flatten(x, 1)
|
||||
|
||||
model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
|
||||
|
||||
print_rank_0('model is created')
|
||||
|
||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
@@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_pipeline_schedule():
|
||||
world_size = 4
|
||||
world_size = 2
|
||||
run_func = partial(run_schedule, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
@@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||
|
||||
|
||||
def build_pipeline(model):
|
||||
from colossalai.builder.pipeline import partition_uniform
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
@@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||
|
||||
|
||||
def build_pipeline(model):
|
||||
from colossalai.builder.pipeline import partition_uniform
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
@@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||
|
||||
|
||||
def build_pipeline(model):
|
||||
from colossalai.builder.pipeline import partition_uniform
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
@@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||
|
||||
|
||||
def build_pipeline(model):
|
||||
from colossalai.builder.pipeline import partition_uniform
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
Reference in New Issue
Block a user