mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 13:41:43 +00:00
add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)
This commit is contained in:
parent
91c327cb44
commit
8f02a88db2
@ -20,10 +20,16 @@ def convert_to_naive_amp(model: nn.Module,
|
|||||||
:return: (model, optimizer)
|
:return: (model, optimizer)
|
||||||
:rtype: Tuple
|
:rtype: Tuple
|
||||||
"""
|
"""
|
||||||
if is_no_pp_or_last_stage():
|
if isinstance(model, nn.ModuleList):
|
||||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
# interleaved pipeline
|
||||||
|
module_list = []
|
||||||
|
for chunk, m in enumerate(model):
|
||||||
|
output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
|
||||||
|
module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
|
||||||
|
model = nn.ModuleList(module_list)
|
||||||
else:
|
else:
|
||||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
output_to_fp32 = is_no_pp_or_last_stage()
|
||||||
|
model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
|
||||||
|
|
||||||
optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
|
optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
|
||||||
return model, optimizer
|
return model, optimizer
|
||||||
|
@ -14,7 +14,7 @@ from colossalai.context.parallel_mode import ParallelMode
|
|||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
|
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
|
||||||
clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
|
clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier, is_using_pp)
|
||||||
|
|
||||||
|
|
||||||
def _zero_grad_group_helper(group, set_to_none):
|
def _zero_grad_group_helper(group, set_to_none):
|
||||||
@ -58,7 +58,8 @@ class DynamicGradScaler:
|
|||||||
backoff_factor,
|
backoff_factor,
|
||||||
growth_interval,
|
growth_interval,
|
||||||
hysteresis,
|
hysteresis,
|
||||||
max_scale: int = None):
|
max_scale: int = None,
|
||||||
|
verbose: bool = False):
|
||||||
""""Grad scaler with dynamic scale that gets adjusted
|
""""Grad scaler with dynamic scale that gets adjusted
|
||||||
during training."""
|
during training."""
|
||||||
assert initial_scale > 0.0
|
assert initial_scale > 0.0
|
||||||
@ -91,6 +92,7 @@ class DynamicGradScaler:
|
|||||||
self._hysteresis_tracker = self.hysteresis
|
self._hysteresis_tracker = self.hysteresis
|
||||||
|
|
||||||
self._logger = get_dist_logger()
|
self._logger = get_dist_logger()
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self):
|
def scale(self):
|
||||||
@ -111,7 +113,8 @@ class DynamicGradScaler:
|
|||||||
if self._hysteresis_tracker <= 0:
|
if self._hysteresis_tracker <= 0:
|
||||||
self._scale = torch.max(self._scale * self.backoff_factor,
|
self._scale = torch.max(self._scale * self.backoff_factor,
|
||||||
self.min_scale)
|
self.min_scale)
|
||||||
self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
|
if self.verbose:
|
||||||
|
self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
|
||||||
else:
|
else:
|
||||||
# If there is no nan/inf, increment the growth tracker.
|
# If there is no nan/inf, increment the growth tracker.
|
||||||
self._growth_tracker += 1
|
self._growth_tracker += 1
|
||||||
@ -122,11 +125,14 @@ class DynamicGradScaler:
|
|||||||
self._hysteresis_tracker = self.hysteresis
|
self._hysteresis_tracker = self.hysteresis
|
||||||
# and scale up the loss scale.
|
# and scale up the loss scale.
|
||||||
if self._max_scale is not None and self._scale >= self._max_scale:
|
if self._max_scale is not None and self._scale >= self._max_scale:
|
||||||
self._logger.info(
|
if self.verbose:
|
||||||
f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
|
self._logger.info(
|
||||||
|
f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
|
||||||
else:
|
else:
|
||||||
self._scale = self._scale * self.growth_factor
|
self._scale = self._scale * self.growth_factor
|
||||||
self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
|
if self.verbose:
|
||||||
|
self._logger.info(
|
||||||
|
f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
@ -162,6 +168,8 @@ class FP16Optimizer(Optimizer):
|
|||||||
:type hysterisis: int
|
:type hysterisis: int
|
||||||
:param max_scale: maximum loss scale allowed
|
:param max_scale: maximum loss scale allowed
|
||||||
:type max_scale: int
|
:type max_scale: int
|
||||||
|
:param verbose: if set to `True`, will print debug info
|
||||||
|
:type verbose: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -174,27 +182,29 @@ class FP16Optimizer(Optimizer):
|
|||||||
backoff_factor=0.5,
|
backoff_factor=0.5,
|
||||||
growth_interval=1000,
|
growth_interval=1000,
|
||||||
hysteresis=2,
|
hysteresis=2,
|
||||||
max_scale: int = 2 ** 32):
|
max_scale: int = 2 ** 32,
|
||||||
|
verbose: bool = False):
|
||||||
# default args for compatibility
|
# default args for compatibility
|
||||||
bf16 = False
|
bf16 = False
|
||||||
params_have_main_grad = True
|
params_have_main_grad = False
|
||||||
|
|
||||||
# have a defaults for compatibility with pytorch optim
|
# have a defaults for compatibility with pytorch optim
|
||||||
self.defaults = optimizer.defaults
|
self.defaults = optimizer.defaults
|
||||||
|
|
||||||
# log config
|
# log config
|
||||||
self._logger = get_dist_logger()
|
self._logger = get_dist_logger()
|
||||||
self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
|
if verbose:
|
||||||
f"Optimizer: {optimizer.__class__.__name__}\n"
|
self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
|
||||||
f"clip_grad = {clip_grad}\n"
|
f"Optimizer: {optimizer.__class__.__name__}\n"
|
||||||
f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
|
f"clip_grad = {clip_grad}\n"
|
||||||
f"initial_scale = {initial_scale}\n"
|
f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
|
||||||
f"min_scale = {min_scale}\n"
|
f"initial_scale = {initial_scale}\n"
|
||||||
f"growth_factor = {growth_factor}\n"
|
f"min_scale = {min_scale}\n"
|
||||||
f"backoff_factor = {backoff_factor}\n"
|
f"growth_factor = {growth_factor}\n"
|
||||||
f"growth_interval = {growth_interval}\n"
|
f"backoff_factor = {backoff_factor}\n"
|
||||||
f"hysteresis = {hysteresis}\n"
|
f"growth_interval = {growth_interval}\n"
|
||||||
f"==========================================", ranks=[0])
|
f"hysteresis = {hysteresis}\n"
|
||||||
|
f"==========================================", ranks=[0])
|
||||||
|
|
||||||
"""Input optimizer is the base optimizer for example Adam."""
|
"""Input optimizer is the base optimizer for example Adam."""
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -212,7 +222,8 @@ class FP16Optimizer(Optimizer):
|
|||||||
backoff_factor=backoff_factor,
|
backoff_factor=backoff_factor,
|
||||||
growth_interval=growth_interval,
|
growth_interval=growth_interval,
|
||||||
hysteresis=hysteresis,
|
hysteresis=hysteresis,
|
||||||
max_scale=max_scale
|
max_scale=max_scale,
|
||||||
|
verbose=verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
# None grad scaler is only supported for bf16.
|
# None grad scaler is only supported for bf16.
|
||||||
@ -350,6 +361,11 @@ class FP16Optimizer(Optimizer):
|
|||||||
op=torch.distributed.ReduceOp.MAX,
|
op=torch.distributed.ReduceOp.MAX,
|
||||||
group=gpc.get_group(ParallelMode.TENSOR))
|
group=gpc.get_group(ParallelMode.TENSOR))
|
||||||
|
|
||||||
|
if is_using_pp():
|
||||||
|
torch.distributed.all_reduce(self.found_inf,
|
||||||
|
op=torch.distributed.ReduceOp.MAX,
|
||||||
|
group=gpc.get_group(ParallelMode.PIPELINE))
|
||||||
|
|
||||||
# Check for nan.
|
# Check for nan.
|
||||||
found_inf_flag = (self.found_inf.item() > 0)
|
found_inf_flag = (self.found_inf.item() > 0)
|
||||||
return found_inf_flag
|
return found_inf_flag
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
|
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
|
||||||
build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
|
build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
|
||||||
build_gradient_handler)
|
build_gradient_handler)
|
||||||
from .pipeline import PipelineModelInitializer
|
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
|
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
|
||||||
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
|
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
|
||||||
'build_gradient_handler', 'PipelineModelInitializer'
|
'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
|
||||||
]
|
]
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import copy
|
import copy
|
||||||
import heapq
|
import heapq
|
||||||
|
|
||||||
|
|
||||||
from colossalai.builder import build_model, build_layer
|
from colossalai.builder import build_model, build_layer
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import set_to_cuda
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def _binary_partition(weights, st, ed):
|
def _binary_partition(weights, st, ed):
|
||||||
@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
|||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
class PipelineModelInitializer():
|
def _count_layer_params(layers):
    """Count the number of trainable parameters in each layer.

    Every entry of ``layers`` is passed to ``build_layer``; the element counts
    (``numel()``) of the resulting layer's parameters whose ``requires_grad``
    flag is set are summed.

    :param layers: layer configs, each accepted by ``build_layer``
    :type layers: list
    :return: one parameter count per entry of ``layers``, in the same order
    :rtype: list
    """
    param_counts = [0] * len(layers)
    for idx, cfg in enumerate(layers):
        layer = build_layer(cfg)
        # Only parameters that require gradients contribute to the count.
        param_counts[idx] = sum(p.numel() for p in layer.parameters() if p.requires_grad)

    return param_counts
|
||||||
|
|
||||||
|
|
||||||
|
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False):
|
||||||
"""An intializer to split the model into different stages for pipeline parallelism.
|
"""An intializer to split the model into different stages for pipeline parallelism.
|
||||||
|
|
||||||
An example for the model config is shown below. The class VisionTransformerFromConfig should
|
An example for the model config is shown below. The class VisionTransformerFromConfig should
|
||||||
@ -168,88 +181,86 @@ class PipelineModelInitializer():
|
|||||||
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
|
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
|
||||||
in most cases unless you are using virutal pipeline parallelism.
|
in most cases unless you are using virutal pipeline parallelism.
|
||||||
:type num_chunks: int
|
:type num_chunks: int
|
||||||
|
:param partition_method: this parameter determines how you want to split your model layers into stages,
|
||||||
|
you can set it as 'layer' or 'parameter'
|
||||||
|
:type partition_method: str
|
||||||
:param verbose: whether to print the logs
|
:param verbose: whether to print the logs
|
||||||
:type verbose: bool
|
:type verbose: bool
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
ori_model = build_model(config)
|
||||||
|
layers = ori_model.layers_cfg
|
||||||
|
layer_length = len(layers)
|
||||||
|
logger = get_dist_logger()
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
|
||||||
|
|
||||||
def __init__(self, config, num_chunks, verbose=False):
|
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
self.num_chunks = num_chunks
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
self.ori_model = build_model(config)
|
|
||||||
self.layers = self.ori_model.layers_cfg
|
|
||||||
layer_length = len(self.layers)
|
|
||||||
self.verbose = verbose
|
|
||||||
self._logger = get_dist_logger()
|
|
||||||
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
|
|
||||||
|
|
||||||
def initialize(self, partition_method='parameter'):
|
method = partition_method.lower()
|
||||||
"""Initialize the model object from the config passed
|
# Make a partition
|
||||||
|
if method == 'layer':
|
||||||
|
num_layers = len(layers)
|
||||||
|
parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
|
||||||
|
elif method == 'parameter':
|
||||||
|
param_counts = _count_layer_params(layers)
|
||||||
|
# print_rank_0(param_counts)
|
||||||
|
parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
|
||||||
|
else:
|
||||||
|
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
||||||
|
|
||||||
:param partition_method: this parameter determines how you want to split your model layers into stages,
|
# Display the partition
|
||||||
you can set it as 'layer' or 'parameter'
|
if verbose:
|
||||||
:type partition_method: str
|
log_str = 'Layer allocation after partitioning: \n'
|
||||||
|
for stage in range(pipeline_parallel_size):
|
||||||
"""
|
|
||||||
# Some space for initializing comunication groups
|
|
||||||
self._interval = None
|
|
||||||
self._partition_layers(method=partition_method)
|
|
||||||
models = self._build()
|
|
||||||
model = set_to_cuda(models)
|
|
||||||
|
|
||||||
return model
|
num_layers = 0
|
||||||
|
for st, ed in parts[stage]:
|
||||||
|
num_layers += ed - st
|
||||||
|
|
||||||
def _partition_layers(self, method):
|
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
|
||||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
for st, ed in parts[stage]:
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
for idx, layer in enumerate(layers[st: ed]):
|
||||||
|
log_str += f'\t{idx + st:2d}: {layer}\n'
|
||||||
|
logger.info(log_str, ranks=[0])
|
||||||
|
|
||||||
method = method.lower()
|
# Save the partition
|
||||||
# Make a partition
|
interval = parts[pipeline_rank]
|
||||||
if method == 'layer':
|
|
||||||
num_layers = len(self.layers)
|
|
||||||
self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
|
|
||||||
elif method == 'parameter':
|
|
||||||
param_counts = self._count_layer_params()
|
|
||||||
# print_rank_0(param_counts)
|
|
||||||
self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
|
|
||||||
else:
|
|
||||||
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
|
||||||
|
|
||||||
# Display the partition
|
models = []
|
||||||
if gpc.get_global_rank() == 0 and self.verbose:
|
for st, ed in interval:
|
||||||
log_str = 'Layer allocation after partitioning: \n'
|
model = copy.deepcopy(ori_model)
|
||||||
for stage in range(pipeline_parallel_size):
|
model.build_from_cfg(st, ed)
|
||||||
|
models.append(model)
|
||||||
|
|
||||||
num_layers = 0
|
return nn.ModuleList(models) if len(models) > 1 else models[0]
|
||||||
for st, ed in self.parts[stage]:
|
|
||||||
num_layers += ed - st
|
|
||||||
|
|
||||||
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
|
|
||||||
for st, ed in self.parts[stage]:
|
|
||||||
for idx, layer in enumerate(self.layers[st: ed]):
|
|
||||||
log_str += f'\t{idx + st:2d}: {layer}\n'
|
|
||||||
self._logger.info(log_str, ranks=[0])
|
|
||||||
|
|
||||||
# Save the partition
|
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
|
||||||
self._interval = self.parts[pipeline_rank]
|
"""An intializer to split the model into different stages for pipeline parallelism.
|
||||||
|
Note that `layers` must be a `torch.nn.Sequential`.
|
||||||
|
|
||||||
def _build(self):
|
:param layers: layers of model
|
||||||
"""Build model from the layer cfg according to the partition
|
:type layers: `torch.nn.Sequential`
|
||||||
"""
|
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
|
||||||
models = []
|
in most cases unless you are using virtual pipeline parallelism.
|
||||||
for st, ed in self._interval:
|
:type num_chunks: int
|
||||||
model = copy.copy(self.ori_model)
|
:param verbose: whether to print the logs
|
||||||
model.build_from_cfg(st, ed)
|
:type verbose: bool
|
||||||
models.append(model)
|
"""
|
||||||
|
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
return models
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
|
||||||
def _count_layer_params(self):
|
module_list = []
|
||||||
"""Count the number of parameters in each layer
|
for start, end in partitions[pipeline_rank]:
|
||||||
"""
|
module_list.append(nn.Sequential(*layers[start:end]))
|
||||||
param_counts = [0] * len(self.layers)
|
if verbose:
|
||||||
for idx, cfg in enumerate(self.layers):
|
logger = get_dist_logger()
|
||||||
layer = build_layer(cfg)
|
logger.info(f'Total {len(layers)} layers', ranks=[0])
|
||||||
params = filter(lambda p: p.requires_grad, layer.parameters())
|
for rank, part in enumerate(partitions):
|
||||||
param_counts[idx] = sum(p.numel() for p in params)
|
log_str = f'===== stage={rank} =====\n'
|
||||||
|
for chunk, (start, end) in enumerate(part):
|
||||||
return param_counts
|
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
|
||||||
|
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
|
||||||
|
logger.info(log_str, ranks=[0])
|
||||||
|
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
|
||||||
|
@ -63,9 +63,6 @@ def _communicate(tensor_send_next=None,
|
|||||||
next_rank = gpc.get_next_global_rank(
|
next_rank = gpc.get_next_global_rank(
|
||||||
ParallelMode.PIPELINE)
|
ParallelMode.PIPELINE)
|
||||||
|
|
||||||
# rank = dist.get_rank()
|
|
||||||
rank = gpc.get_global_rank()
|
|
||||||
|
|
||||||
ops = []
|
ops = []
|
||||||
if tensor_send_prev is not None:
|
if tensor_send_prev is not None:
|
||||||
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
|
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
|
||||||
@ -88,7 +85,7 @@ def _communicate(tensor_send_next=None,
|
|||||||
return tensor_recv_prev, tensor_recv_next
|
return tensor_recv_prev, tensor_recv_next
|
||||||
|
|
||||||
|
|
||||||
def recv_forward(input_tensor_shape, prev_rank=None):
|
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
|
||||||
"""Receives the input tensor from the previous member in pipeline.
|
"""Receives the input tensor from the previous member in pipeline.
|
||||||
|
|
||||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||||
@ -98,16 +95,17 @@ def recv_forward(input_tensor_shape, prev_rank=None):
|
|||||||
:return: The input tensor in forward step
|
:return: The input tensor in forward step
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
if gpc.is_pipeline_first_stage():
|
||||||
input_tensor = None
|
input_tensor = None
|
||||||
else:
|
else:
|
||||||
input_tensor, _ = _communicate(recv_prev=True,
|
input_tensor, _ = _communicate(recv_prev=True,
|
||||||
recv_prev_shape=input_tensor_shape,
|
recv_prev_shape=input_tensor_shape,
|
||||||
prev_rank=prev_rank)
|
prev_rank=prev_rank,
|
||||||
|
dtype=dtype)
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
def recv_backward(output_grad_shape, next_rank=None):
|
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
|
||||||
"""Receives the grad tensor from the next member in pipeline.
|
"""Receives the grad tensor from the next member in pipeline.
|
||||||
|
|
||||||
:param output_grad_shape: The shape of the tensor to be recieved
|
:param output_grad_shape: The shape of the tensor to be recieved
|
||||||
@ -117,12 +115,13 @@ def recv_backward(output_grad_shape, next_rank=None):
|
|||||||
:return: The grad of output tensor in forward step
|
:return: The grad of output tensor in forward step
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_pipeline_last_stage():
|
||||||
output_tensor_grad = None
|
output_tensor_grad = None
|
||||||
else:
|
else:
|
||||||
_, output_tensor_grad = _communicate(recv_next=True,
|
_, output_tensor_grad = _communicate(recv_next=True,
|
||||||
recv_next_shape=output_grad_shape,
|
recv_next_shape=output_grad_shape,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank,
|
||||||
|
dtype=dtype)
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
|
|
||||||
@ -134,7 +133,7 @@ def send_forward(output_tensor, next_rank=None):
|
|||||||
:type output_tensor: :class:`torch.Tensor`
|
:type output_tensor: :class:`torch.Tensor`
|
||||||
:type next_rank: int, optional
|
:type next_rank: int, optional
|
||||||
"""
|
"""
|
||||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
if not gpc.is_pipeline_last_stage():
|
||||||
_communicate(tensor_send_next=output_tensor,
|
_communicate(tensor_send_next=output_tensor,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank)
|
||||||
|
|
||||||
@ -147,7 +146,7 @@ def send_backward(input_tensor_grad, prev_rank=None):
|
|||||||
:type input_tensor_grad: :class:`torch.Tensor`
|
:type input_tensor_grad: :class:`torch.Tensor`
|
||||||
:type prev_rank: int, optional
|
:type prev_rank: int, optional
|
||||||
"""
|
"""
|
||||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
if not gpc.is_pipeline_first_stage():
|
||||||
_communicate(tensor_send_prev=input_tensor_grad,
|
_communicate(tensor_send_prev=input_tensor_grad,
|
||||||
prev_rank=prev_rank)
|
prev_rank=prev_rank)
|
||||||
|
|
||||||
@ -155,7 +154,8 @@ def send_backward(input_tensor_grad, prev_rank=None):
|
|||||||
def send_forward_recv_backward(output_tensor,
|
def send_forward_recv_backward(output_tensor,
|
||||||
output_grad_shape,
|
output_grad_shape,
|
||||||
recv_next=True,
|
recv_next=True,
|
||||||
next_rank=None):
|
next_rank=None,
|
||||||
|
dtype=torch.float):
|
||||||
"""Batched communication operation. Sends the input tensor to the
|
"""Batched communication operation. Sends the input tensor to the
|
||||||
next member in pipeline, while recieves the grad tensor from the
|
next member in pipeline, while recieves the grad tensor from the
|
||||||
next member in pipeline.
|
next member in pipeline.
|
||||||
@ -167,20 +167,22 @@ def send_forward_recv_backward(output_tensor,
|
|||||||
:return: The grad of output tensor in forward step
|
:return: The grad of output tensor in forward step
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_pipeline_last_stage():
|
||||||
output_tensor_grad = None
|
output_tensor_grad = None
|
||||||
else:
|
else:
|
||||||
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
|
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
|
||||||
recv_next=recv_next,
|
recv_next=recv_next,
|
||||||
recv_next_shape=output_grad_shape,
|
recv_next_shape=output_grad_shape,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank,
|
||||||
|
dtype=dtype)
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
|
|
||||||
def send_backward_recv_forward(input_tensor_grad,
|
def send_backward_recv_forward(input_tensor_grad,
|
||||||
input_tensor_shape,
|
input_tensor_shape,
|
||||||
recv_prev=True,
|
recv_prev=True,
|
||||||
prev_rank=None):
|
prev_rank=None,
|
||||||
|
dtype=torch.float):
|
||||||
"""Batched communication operation. Sends the grad tensor to the
|
"""Batched communication operation. Sends the grad tensor to the
|
||||||
previous member in pipeline, while recieves the input tensor from the
|
previous member in pipeline, while recieves the input tensor from the
|
||||||
previous member in pipeline.
|
previous member in pipeline.
|
||||||
@ -192,13 +194,14 @@ def send_backward_recv_forward(input_tensor_grad,
|
|||||||
:return: The input tensor in forward step
|
:return: The input tensor in forward step
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
if gpc.is_pipeline_first_stage():
|
||||||
input_tensor = None
|
input_tensor = None
|
||||||
else:
|
else:
|
||||||
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
|
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
|
||||||
recv_prev=recv_prev,
|
recv_prev=recv_prev,
|
||||||
recv_prev_shape=input_tensor_shape,
|
recv_prev_shape=input_tensor_shape,
|
||||||
prev_rank=prev_rank)
|
prev_rank=prev_rank,
|
||||||
|
dtype=dtype)
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
@ -206,7 +209,8 @@ def send_forward_recv_forward(output_tensor,
|
|||||||
input_tensor_shape,
|
input_tensor_shape,
|
||||||
recv_prev=True,
|
recv_prev=True,
|
||||||
prev_rank=None,
|
prev_rank=None,
|
||||||
next_rank=None):
|
next_rank=None,
|
||||||
|
dtype=torch.float):
|
||||||
"""Batched communication operation. Sends the input tensor to the
|
"""Batched communication operation. Sends the input tensor to the
|
||||||
next member in pipeline, while recieves the input tensor from the
|
next member in pipeline, while recieves the input tensor from the
|
||||||
previous member in pipeline.
|
previous member in pipeline.
|
||||||
@ -222,7 +226,8 @@ def send_forward_recv_forward(output_tensor,
|
|||||||
recv_prev=recv_prev,
|
recv_prev=recv_prev,
|
||||||
recv_prev_shape=input_tensor_shape,
|
recv_prev_shape=input_tensor_shape,
|
||||||
prev_rank=prev_rank,
|
prev_rank=prev_rank,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank,
|
||||||
|
dtype=dtype)
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
@ -230,7 +235,8 @@ def send_backward_recv_backward(input_tensor_grad,
|
|||||||
output_grad_shape,
|
output_grad_shape,
|
||||||
recv_next=True,
|
recv_next=True,
|
||||||
prev_rank=None,
|
prev_rank=None,
|
||||||
next_rank=None):
|
next_rank=None,
|
||||||
|
dtype=torch.float):
|
||||||
"""Batched communication operation. Sends the grad tensor to the
|
"""Batched communication operation. Sends the grad tensor to the
|
||||||
previous member in pipeline, while recieves the grad tensor from the
|
previous member in pipeline, while recieves the grad tensor from the
|
||||||
next member in pipeline.
|
next member in pipeline.
|
||||||
@ -246,7 +252,8 @@ def send_backward_recv_backward(input_tensor_grad,
|
|||||||
recv_next=recv_next,
|
recv_next=recv_next,
|
||||||
recv_next_shape=output_grad_shape,
|
recv_next_shape=output_grad_shape,
|
||||||
prev_rank=prev_rank,
|
prev_rank=prev_rank,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank,
|
||||||
|
dtype=dtype)
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
|
|
||||||
@ -257,7 +264,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
|||||||
recv_prev=True,
|
recv_prev=True,
|
||||||
recv_next=True,
|
recv_next=True,
|
||||||
prev_rank=None,
|
prev_rank=None,
|
||||||
next_rank=None):
|
next_rank=None,
|
||||||
|
dtype=torch.float):
|
||||||
"""Batched communication operation. Sends the input tensor to the next and
|
"""Batched communication operation. Sends the input tensor to the next and
|
||||||
the grad tensor to the previous, while recieves the grad tensor from the
|
the grad tensor to the previous, while recieves the grad tensor from the
|
||||||
next and the input tensor from the previous.
|
next and the input tensor from the previous.
|
||||||
@ -281,5 +289,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
|||||||
recv_prev_shape=input_tensor_shape,
|
recv_prev_shape=input_tensor_shape,
|
||||||
recv_next_shape=output_grad_shape,
|
recv_next_shape=output_grad_shape,
|
||||||
prev_rank=prev_rank,
|
prev_rank=prev_rank,
|
||||||
next_rank=next_rank)
|
next_rank=next_rank,
|
||||||
|
dtype=dtype)
|
||||||
return input_tensor, output_tensor_grad
|
return input_tensor, output_tensor_grad
|
||||||
|
@ -29,14 +29,8 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
|
|||||||
|
|
||||||
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
|
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
|
||||||
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
|
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
|
||||||
ops = [
|
dist.send(send_ndims, next_rank)
|
||||||
dist.P2POp(dist.isend, send_ndims, next_rank),
|
dist.send(send_shape, next_rank)
|
||||||
dist.P2POp(dist.isend, send_shape, next_rank)
|
|
||||||
]
|
|
||||||
reqs = dist.batch_isend_irecv(ops)
|
|
||||||
for req in reqs:
|
|
||||||
req.wait()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -53,6 +53,8 @@ class ParallelContext:
|
|||||||
self.data_parallel_size = 1
|
self.data_parallel_size = 1
|
||||||
self.pipeline_parallel_size = 1
|
self.pipeline_parallel_size = 1
|
||||||
self.tensor_parallel_size = 1
|
self.tensor_parallel_size = 1
|
||||||
|
self.virtual_pipeline_parallel_size = None
|
||||||
|
self.virtual_pipeline_parallel_rank = None
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
self._verbose = False
|
self._verbose = False
|
||||||
@ -205,6 +207,18 @@ class ParallelContext:
|
|||||||
world_size = self.get_world_size(parallel_mode)
|
world_size = self.get_world_size(parallel_mode)
|
||||||
return rank == world_size - 1
|
return rank == world_size - 1
|
||||||
|
|
||||||
|
def is_pipeline_first_stage(self, ignore_virtual=False):
    """Return whether this process is on the first pipeline stage.

    Unless ``ignore_virtual`` is set, a configured virtual pipeline size
    (``virtual_pipeline_parallel_size`` is not ``None``) with a virtual rank
    other than 0 makes the answer ``False`` without consulting the pipeline
    rank. In every other case the result is
    ``self.is_first_rank(ParallelMode.PIPELINE)``.

    :param ignore_virtual: skip the virtual pipeline rank check
    :type ignore_virtual: bool, optional
    :return: ``False`` on the virtual-rank early exit, otherwise the result of ``is_first_rank``
    """
    check_virtual = not ignore_virtual and self.virtual_pipeline_parallel_size is not None
    if check_virtual and self.virtual_pipeline_parallel_rank != 0:
        return False
    return self.is_first_rank(ParallelMode.PIPELINE)
|
||||||
|
|
||||||
|
def is_pipeline_last_stage(self, ignore_virtual=False):
    """Return whether this process is on the last pipeline stage.

    Unless ``ignore_virtual`` is set, a configured virtual pipeline size
    (``virtual_pipeline_parallel_size`` is not ``None``) with a virtual rank
    other than ``virtual_pipeline_parallel_size - 1`` makes the answer
    ``False`` without consulting the pipeline rank. In every other case the
    result is ``self.is_last_rank(ParallelMode.PIPELINE)``.

    :param ignore_virtual: skip the virtual pipeline rank check
    :type ignore_virtual: bool, optional
    :return: ``False`` on the virtual-rank early exit, otherwise the result of ``is_last_rank``
    """
    check_virtual = not ignore_virtual and self.virtual_pipeline_parallel_size is not None
    if check_virtual:
        last_virtual_rank = self.virtual_pipeline_parallel_size - 1
        if self.virtual_pipeline_parallel_rank != last_virtual_rank:
            return False
    return self.is_last_rank(ParallelMode.PIPELINE)
|
||||||
|
|
||||||
def get_world_size(self, parallel_mode: ParallelMode):
|
def get_world_size(self, parallel_mode: ParallelMode):
|
||||||
"""Returns the world size for `parallel_mode`.
|
"""Returns the world size for `parallel_mode`.
|
||||||
|
|
||||||
@ -494,3 +508,9 @@ class ParallelContext:
|
|||||||
self._logger.info(
|
self._logger.info(
|
||||||
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
|
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
|
|
||||||
|
def set_virtual_pipeline_parallel_size(self, size):
    """Store the virtual pipeline parallel size on this context.

    The value is read by ``is_pipeline_first_stage`` and
    ``is_pipeline_last_stage``; ``None`` means no virtual pipeline is in use.

    :param size: Value to store as ``virtual_pipeline_parallel_size``
    """
    self.virtual_pipeline_parallel_size = size
|
||||||
|
|
||||||
|
def set_virtual_pipeline_parallel_rank(self, rank):
    """Store the current virtual pipeline parallel rank on this context.

    The value is read by ``is_pipeline_first_stage`` and
    ``is_pipeline_last_stage`` when a virtual pipeline size is set.

    :param rank: Value to store as ``virtual_pipeline_parallel_rank``
    """
    self.virtual_pipeline_parallel_rank = rank
|
||||||
|
@ -32,7 +32,7 @@ class DataParallelGradientHandler(BaseGradientHandler):
|
|||||||
if tp not in buckets:
|
if tp not in buckets:
|
||||||
buckets[tp] = []
|
buckets[tp] = []
|
||||||
buckets[tp].append(param)
|
buckets[tp].append(param)
|
||||||
param.main_grad = param.grad
|
# param.main_grad = param.grad
|
||||||
|
|
||||||
# For each bucket, all-reduce and copy all-reduced grads.
|
# For each bucket, all-reduce and copy all-reduced grads.
|
||||||
for tp in buckets:
|
for tp in buckets:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
from ._pipeline_schedule import PipelineSchedule
|
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
|
||||||
from ._non_pipeline_schedule import NonPipelineSchedule
|
from ._non_pipeline_schedule import NonPipelineSchedule
|
||||||
|
|
||||||
__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']
|
__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule', 'InterleavedPipelineSchedule']
|
||||||
|
@ -13,9 +13,8 @@ from colossalai.core import global_context as gpc
|
|||||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||||
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
||||||
ZeroRedundancyOptimizer_Level_3)
|
ZeroRedundancyOptimizer_Level_3)
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
from colossalai.amp import AMP_TYPE
|
|
||||||
|
|
||||||
|
|
||||||
def squeeze(x: Union[Tensor, tuple, list]):
|
def squeeze(x: Union[Tensor, tuple, list]):
|
||||||
@ -47,6 +46,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
|
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.sync_data = sync_data
|
self.sync_data = sync_data
|
||||||
|
self.dtype = torch.float
|
||||||
|
|
||||||
def _move_to_device(self, data):
|
def _move_to_device(self, data):
|
||||||
if isinstance(data, (
|
if isinstance(data, (
|
||||||
@ -122,12 +122,8 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
||||||
)
|
)
|
||||||
|
|
||||||
# LSG: set default dtype to fp16 for communication
|
|
||||||
if isinstance(engine.model, NaiveAMPModel):
|
if isinstance(engine.model, NaiveAMPModel):
|
||||||
torch.set_default_dtype(torch.half)
|
self.dtype = torch.half
|
||||||
self.logger.warning(
|
|
||||||
'default tensor dtype is set to torch.half for fp16 training',
|
|
||||||
ranks=[0])
|
|
||||||
|
|
||||||
def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
|
def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
|
||||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||||
@ -140,7 +136,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
:type input_tensor: :class:`torch.Tensor`
|
:type input_tensor: :class:`torch.Tensor`
|
||||||
:param return_tensors: a list of tensors to return
|
:param return_tensors: a list of tensors to return
|
||||||
:type return_tensors: List[:class:`torch.Tensor`]
|
:type return_tensors: List[:class:`torch.Tensor`]
|
||||||
|
|
||||||
:return: output or the loss value of the current pipeline stage
|
:return: output or the loss value of the current pipeline stage
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
@ -252,7 +248,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
for i in range(num_warmup_microbatches):
|
for i in range(num_warmup_microbatches):
|
||||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||||
ft_shape = recv_tensor_meta(ft_shape)
|
ft_shape = recv_tensor_meta(ft_shape)
|
||||||
input_tensor = recv_forward(ft_shape)
|
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
|
||||||
output_tensor = self.forward_step(
|
output_tensor = self.forward_step(
|
||||||
engine, input_tensor, return_tensors,
|
engine, input_tensor, return_tensors,
|
||||||
return_loss=return_loss
|
return_loss=return_loss
|
||||||
@ -272,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
if num_microbatches_remaining > 0:
|
if num_microbatches_remaining > 0:
|
||||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||||
ft_shape = recv_tensor_meta(ft_shape)
|
ft_shape = recv_tensor_meta(ft_shape)
|
||||||
input_tensor = recv_forward(ft_shape)
|
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
|
||||||
|
|
||||||
# Run 1F1B in steady state.
|
# Run 1F1B in steady state.
|
||||||
for i in range(num_microbatches_remaining):
|
for i in range(num_microbatches_remaining):
|
||||||
@ -286,11 +282,11 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
send_forward(output_tensor)
|
send_forward(output_tensor)
|
||||||
|
|
||||||
if not last_iteration:
|
if not last_iteration:
|
||||||
input_tensor = recv_forward(ft_shape)
|
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
output_tensor_grad = send_forward_recv_backward(
|
output_tensor_grad = send_forward_recv_backward(
|
||||||
output_tensor, bt_shape)
|
output_tensor, bt_shape, dtype=self.dtype)
|
||||||
|
|
||||||
# Add input_tensor and output_tensor to end of list.
|
# Add input_tensor and output_tensor to end of list.
|
||||||
input_tensors.append(input_tensor)
|
input_tensors.append(input_tensor)
|
||||||
@ -312,7 +308,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
send_backward(input_tensor_grad)
|
send_backward(input_tensor_grad)
|
||||||
else:
|
else:
|
||||||
input_tensor = send_backward_recv_forward(
|
input_tensor = send_backward_recv_forward(
|
||||||
input_tensor_grad, ft_shape)
|
input_tensor_grad, ft_shape, dtype=self.dtype)
|
||||||
|
|
||||||
# Run cooldown backward passes.
|
# Run cooldown backward passes.
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
@ -320,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
input_tensor = input_tensors.pop(0)
|
input_tensor = input_tensors.pop(0)
|
||||||
output_tensor = output_tensors.pop(0)
|
output_tensor = output_tensors.pop(0)
|
||||||
|
|
||||||
output_tensor_grad = recv_backward(bt_shape)
|
output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
|
||||||
|
|
||||||
input_tensor_grad = self.backward_step(
|
input_tensor_grad = self.backward_step(
|
||||||
engine,
|
engine,
|
||||||
@ -340,3 +336,309 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
||||||
else:
|
else:
|
||||||
return tuple((None, None, None))
|
return tuple((None, None, None))
|
||||||
|
|
||||||
|
|
||||||
|
class InterleavedPipelineSchedule(PipelineSchedule):
    """Interleaved 1F1B pipeline schedule where every pipeline rank holds several model chunks.

    ``engine.model`` is treated as a sequence of model chunks (it is indexed
    and its length is taken). The number of chunks is recorded on the global
    context as the virtual pipeline parallel size, and the virtual rank is
    switched to the chunk being processed before each forward/backward step.

    :param num_microbatches: Number of micro-batches per step; must be an integer
        multiple of the pipeline parallel world size
    :type num_microbatches: int
    :param num_model_chunks: Number of model chunks held by each pipeline rank
    :type num_model_chunks: int
    :param sync_data: Forwarded unchanged to :class:`PipelineSchedule`
    :type sync_data: bool, optional
    """

    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
            'num_microbatches must be an integer multiple of pipeline parallel world size'
        super().__init__(num_microbatches, sync_data=sync_data)
        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)

    def pre_processing(self, engine):
        """Validate the engine and pick the communication dtype.

        :raises TypeError: If the optimizer is a ZeRO level 2 or level 3 optimizer
        """
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        # Tensors exchanged between stages are received as fp16 when the
        # chunks are wrapped by naive AMP.
        if isinstance(engine.model[0], NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
        """Forward step for one model chunk. If ``input_tensor`` is None, the input
        is loaded from the current micro-batch, otherwise the passed-in tensor is used.
        This is a helper function and can be ignored by users.

        :param engine: Engine providing the criterion
        :param model: The model chunk to run
        :param input_tensor: Input received from the previous stage, or None
        :param return_tensors: List that collects results on the last pipeline stage
        :param return_loss: Whether to compute the loss on the last pipeline stage
        :return: The loss (already divided by ``self.num_microbatches``) on the last
            stage when ``return_loss`` is True, otherwise the chunk output
        """

        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = model(input_tensor)
        output_tensor = squeeze(output_tensor)

        if gpc.is_pipeline_last_stage():
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor
        else:
            return output_tensor

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
        """Run interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

        :param engine: Engine whose ``model`` is the sequence of model chunks
        :param data_iter: Iterator the batch is loaded from
        :param forward_only: If True, only forward passes are run
        :param return_loss: Whether the last stage computes the loss; must be True
            when ``forward_only`` is False
        :return: ``(output, label, loss)`` when results were collected and
            ``return_loss`` is True, ``(output, None, None)`` when results were
            collected without loss, and ``(None, None, None)`` otherwise
        :rtype: tuple
        """
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)
        model = engine.model
        input_tensors = [[] for _ in range(len(model))]
        output_tensors = [[] for _ in range(len(model))]
        return_tensors = []
        if not forward_only:
            output_tensor_grads = [[] for _ in range(len(model))]

        # Used for tensor meta information communication
        input_tensor_shapes = [None for _ in range(len(model))]
        output_tensor_shapes = [None for _ in range(len(model))]
        send_tensor_shape_flags = [True for _ in range(len(model))]

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # Compute number of warmup and remaining microbatches.
        num_model_chunks = len(model)
        num_microbatches = self.num_microbatches * num_model_chunks
        all_warmup_microbatches = False
        if forward_only:
            num_warmup_microbatches = num_microbatches
        else:
            # Run all forward passes and then all backward passes if number of
            # microbatches is just the number of pipeline stages.
            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
            # all workers, followed by more microbatches after depending on
            # stage ID (more forward passes for earlier stages, later stages can
            # immediately start with 1F1B).
            if self.num_microbatches == pipeline_parallel_size:
                num_warmup_microbatches = num_microbatches
                all_warmup_microbatches = True
            else:
                num_warmup_microbatches = \
                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
                num_warmup_microbatches += (
                    num_model_chunks - 1) * pipeline_parallel_size
                num_warmup_microbatches = min(num_warmup_microbatches,
                                              num_microbatches)
        num_microbatches_remaining = \
            num_microbatches - num_warmup_microbatches

        def get_model_chunk_id(microbatch_id, forward):
            """Helper method to get the model chunk ID given the iteration number."""
            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
            if not forward:
                # Backward passes visit the chunks in reverse order.
                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
            return model_chunk_id

        def forward_step_helper(microbatch_id):
            """Helper method to run forward step with model split into chunks
            (sets the virtual pipeline parallel rank on ``gpc`` before calling
            forward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            # forward step
            if gpc.is_pipeline_first_stage():
                # The first stage receives nothing; queue a None placeholder so
                # forward_step loads the micro-batch itself.
                if len(input_tensors[model_chunk_id]) == \
                        len(output_tensors[model_chunk_id]):
                    input_tensors[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id][-1]
            output_tensor = self.forward_step(
                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
            output_tensors[model_chunk_id].append(output_tensor)

            # if forward-only, no need to save tensors for a backward pass
            if forward_only:
                input_tensors[model_chunk_id].pop()
                output_tensors[model_chunk_id].pop()

            return output_tensor

        def backward_step_helper(microbatch_id):
            """Helper method to run backward step with model split into chunks
            (sets the virtual pipeline parallel rank on ``gpc`` before calling
            backward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            if gpc.is_pipeline_last_stage():
                # The last stage has no incoming gradient; use a None placeholder.
                if len(output_tensor_grads[model_chunk_id]) == 0:
                    output_tensor_grads[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id].pop(0)
            output_tensor = output_tensors[model_chunk_id].pop(0)
            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
            input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

            return input_tensor_grad

        # Run warmup forward passes.
        gpc.set_virtual_pipeline_parallel_rank(0)
        if not gpc.is_pipeline_first_stage():
            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))

        for k in range(num_warmup_microbatches):
            model_chunk_id = get_model_chunk_id(k, forward=True)
            output_tensor = forward_step_helper(k)
            if not gpc.is_pipeline_last_stage():
                output_tensor_shapes[model_chunk_id] = output_tensor.shape
                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
                    output_tensor, send_tensor_shape_flags[model_chunk_id])
            # Determine if tensor should be received from previous stage.
            next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                if next_forward_model_chunk_id == 0:
                    recv_prev = False
            if k == (num_microbatches - 1):
                recv_prev = False

            # Don't send tensor downstream if on last stage.
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                if not gpc.is_pipeline_first_stage():
                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
                        input_tensor_shapes[next_forward_model_chunk_id])
            # Send and receive tensors as appropriate (send tensors computed
            # in this iteration; receive tensors for next iteration).
            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            if k == (num_warmup_microbatches - 1) and not forward_only and \
                    not all_warmup_microbatches:
                input_tensor_grad = None
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
                input_tensor, output_tensor_grad = \
                    send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        input_shape,
                        output_shape,
                        recv_prev=recv_prev, recv_next=recv_next,
                        dtype=self.dtype)
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = \
                    send_forward_recv_forward(
                        output_tensor,
                        input_shape,
                        recv_prev=recv_prev,
                        dtype=self.dtype)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        # Run 1F1B in steady state.
        for k in range(num_microbatches_remaining):
            # Forward pass.
            forward_k = k + num_warmup_microbatches
            output_tensor = forward_step_helper(forward_k)

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True)
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                                 forward=True)

            recv_next = True
            if gpc.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False)
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                                  forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
            # Communicate tensors.
            input_tensor, output_tensor_grad = \
                send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    input_shape,
                    output_shape,
                    recv_prev=recv_prev, recv_next=recv_next,
                    dtype=self.dtype)

            # Put input_tensor and output_tensor_grad in data structures in the
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(
                    output_tensor_grad)

        # Run cooldown backward passes (flush out pipeline).
        if not forward_only:
            if all_warmup_microbatches:
                # Receive with the schedule's communication dtype, like every
                # other receive in this schedule.
                output_tensor_grads[num_model_chunks - 1].append(
                    recv_backward(output_tensor_shapes[num_model_chunks - 1], dtype=self.dtype))
            for k in range(num_microbatches_remaining, num_microbatches):
                input_tensor_grad = backward_step_helper(k)
                next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_model_chunk_id == (num_model_chunks - 1):
                        recv_next = False
                if k == (num_microbatches - 1):
                    recv_next = False
                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                output_tensor_grads[next_backward_model_chunk_id].append(
                    send_backward_recv_backward(
                        input_tensor_grad,
                        output_shape,
                        recv_next=recv_next,
                        dtype=self.dtype))

        if len(return_tensors) > 0:
            if return_loss:
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            return tuple((None, None, None))
|
||||||
|
@ -280,11 +280,21 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||||||
raise ConfigException(
|
raise ConfigException(
|
||||||
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
|
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
|
||||||
|
|
||||||
|
# clip grad norm
|
||||||
|
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
|
||||||
|
if clip_grad_norm > 0:
|
||||||
|
if zero_cfg is not None:
|
||||||
|
raise ConfigException(
|
||||||
|
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
|
||||||
|
|
||||||
# initialize amp
|
# initialize amp
|
||||||
amp_mode = None
|
amp_mode = None
|
||||||
if fp16_cfg is not None and fp16_cfg.mode is not None:
|
if fp16_cfg is not None and fp16_cfg.mode is not None:
|
||||||
|
# TODO: pipeline only support NAIVE AMP
|
||||||
cfg_ = fp16_cfg.copy()
|
cfg_ = fp16_cfg.copy()
|
||||||
amp_mode = cfg_.pop('mode')
|
amp_mode = cfg_.pop('mode')
|
||||||
|
if amp_mode == AMP_TYPE.NAIVE:
|
||||||
|
cfg_['clip_grad'] = clip_grad_norm
|
||||||
model, optimizer, criterion = convert_to_amp(model=model,
|
model, optimizer, criterion = convert_to_amp(model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
@ -357,16 +367,6 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||||||
gradient_handlers=gradient_handlers,
|
gradient_handlers=gradient_handlers,
|
||||||
lr_scheduler=lr_scheduler)
|
lr_scheduler=lr_scheduler)
|
||||||
|
|
||||||
# clip grad norm
|
|
||||||
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
|
|
||||||
if clip_grad_norm > 0:
|
|
||||||
if zero_cfg is not None:
|
|
||||||
raise ConfigException(
|
|
||||||
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
|
|
||||||
elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
|
|
||||||
raise ConfigException(
|
|
||||||
"clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
|
|
||||||
|
|
||||||
engine = Engine(
|
engine = Engine(
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
@ -3,7 +3,7 @@ from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
|
|||||||
is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
|
is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
|
||||||
is_using_pp, conditional_context, is_model_parallel_parameter,
|
is_using_pp, conditional_context, is_model_parallel_parameter,
|
||||||
clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
|
clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
|
||||||
param_is_not_tensor_parallel_duplicate)
|
param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank)
|
||||||
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
|
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
|
||||||
from .memory import report_memory_usage
|
from .memory import report_memory_usage
|
||||||
from .timer import MultiTimer, Timer
|
from .timer import MultiTimer, Timer
|
||||||
@ -22,5 +22,6 @@ __all__ = ['checkpoint',
|
|||||||
'Timer', 'MultiTimer',
|
'Timer', 'MultiTimer',
|
||||||
'multi_tensor_applier',
|
'multi_tensor_applier',
|
||||||
'accumulate_gradient',
|
'accumulate_gradient',
|
||||||
'DataParallelSampler', 'get_dataloader'
|
'DataParallelSampler', 'get_dataloader',
|
||||||
|
'switch_virtual_pipeline_parallel_rank'
|
||||||
]
|
]
|
||||||
|
@ -249,3 +249,13 @@ def param_is_not_tensor_parallel_duplicate(param):
|
|||||||
return (hasattr(param, IS_TENSOR_PARALLEL) and
|
return (hasattr(param, IS_TENSOR_PARALLEL) and
|
||||||
getattr(param, IS_TENSOR_PARALLEL)) or (
|
getattr(param, IS_TENSOR_PARALLEL)) or (
|
||||||
gpc.get_local_rank(ParallelMode.TENSOR) == 0)
|
gpc.get_local_rank(ParallelMode.TENSOR) == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
    """Context manager that temporarily sets the virtual pipeline parallel rank.

    The rank stored on the global context is saved, replaced by ``rank`` for
    the duration of the ``with`` body, and restored afterwards even if the
    body raises.

    :param rank: Virtual pipeline parallel rank to use inside the ``with`` body
    """
    prev_rank = gpc.virtual_pipeline_parallel_rank
    try:
        gpc.set_virtual_pipeline_parallel_rank(rank)
        yield
    finally:
        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
|
||||||
|
@ -172,10 +172,10 @@ elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
|
|||||||
2. Make sure your model inherits from `colossalai.nn.model.ModelFromConfig` and is registered in the
|
2. Make sure your model inherits from `colossalai.nn.model.ModelFromConfig` and is registered in the
|
||||||
`MODELS` registry. Define the `self.layers_cfg` attribute.
|
`MODELS` registry. Define the `self.layers_cfg` attribute.
|
||||||
Pass in a dict/Config object which specifies the parameters of your model.
|
Pass in a dict/Config object which specifies the parameters of your model.
|
||||||
Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.
|
Use `colossalai.builder.pipeline.build_pipeline_model_from_cfg` to partition the layers.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from colossalai.builder import PipelineModelInitializer
|
from colossalai.builder import build_pipeline_model_from_cfg
|
||||||
from colossalai.nn.model import ModelFromConfig
|
from colossalai.nn.model import ModelFromConfig
|
||||||
from colossalai.registry import MODELS
|
from colossalai.registry import MODELS
|
||||||
|
|
||||||
@ -199,8 +199,11 @@ model_cfg = dict(
|
|||||||
...
|
...
|
||||||
)
|
)
|
||||||
|
|
||||||
initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
|
# from config
|
||||||
model = initializer.initialize()
|
model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1)
|
||||||
|
|
||||||
|
# from torch.nn.Sequential
|
||||||
|
# model = build_pipeline_model(sequential_model, num_model_chunks)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -214,6 +217,9 @@ engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criteri
|
|||||||
|
|
||||||
schedule = PipelineSchedule(num_microbatches=4)
|
schedule = PipelineSchedule(num_microbatches=4)
|
||||||
|
|
||||||
|
# interleaved pipeline
|
||||||
|
# schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=2)
|
||||||
|
|
||||||
# execute a training epoch
|
# execute a training epoch
|
||||||
data_iter = iter(train_dataloader)
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from colossalai.logging import get_dist_logger
|
|||||||
import colossalai
|
import colossalai
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
from colossalai.builder import PipelineModelInitializer
|
from colossalai.builder import build_pipeline_model_from_cfg
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.utils import get_dataloader, MultiTimer
|
from colossalai.utils import get_dataloader, MultiTimer
|
||||||
from colossalai.nn.loss import CrossEntropyLoss2D
|
from colossalai.nn.loss import CrossEntropyLoss2D
|
||||||
@ -50,8 +50,7 @@ def test_hybrid_parallel():
|
|||||||
# suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
|
# suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
|
||||||
|
|
||||||
# build vit-t-32
|
# build vit-t-32
|
||||||
initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1)
|
model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)
|
||||||
model = initializer.initialize()
|
|
||||||
|
|
||||||
# build dataloaders
|
# build dataloaders
|
||||||
train_dataset = CIFAR10(
|
train_dataset = CIFAR10(
|
||||||
@ -139,4 +138,4 @@ def test_hybrid_parallel():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
test_hybrid_parallel()
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from colossalai.builder.pipeline import PipelineModelInitializer
|
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
|
||||||
from colossalai.core import global_context
|
from colossalai.core import global_context
|
||||||
from colossalai.initialize import launch
|
from colossalai.initialize import launch
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
@ -28,7 +28,7 @@ def run_partition(rank, world_size):
|
|||||||
logger.info('finished initialization')
|
logger.info('finished initialization')
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
model = PipelineModelInitializer(global_context.config.model, 1, verbose=True).initialize()
|
model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
|
||||||
assert isinstance(model, torch.nn.Module)
|
assert isinstance(model, torch.nn.Module)
|
||||||
logger.info('model is created')
|
logger.info('model is created')
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import model
|
import model
|
||||||
|
|
||||||
from colossalai.builder import PipelineModelInitializer
|
from colossalai.builder import build_pipeline_model_from_cfg
|
||||||
from colossalai.communication import p2p as p2p_communication
|
from colossalai.communication import p2p as p2p_communication
|
||||||
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
|
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
@ -39,7 +39,7 @@ def run_schedule(rank, world_size):
|
|||||||
backend='nccl')
|
backend='nccl')
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
model = PipelineModelInitializer(gpc.config.model, 1).initialize()
|
model = build_pipeline_model_from_cfg(gpc.config.model, 1)
|
||||||
print_rank_0('model is created')
|
print_rank_0('model is created')
|
||||||
|
|
||||||
train_dataset = CIFAR10(
|
train_dataset = CIFAR10(
|
||||||
|
Loading…
Reference in New Issue
Block a user