Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 09:59:38 +00:00
Optimize pipeline schedule (#94)
* add pipeline shared module wrapper and update load batch
* added model parallel process group for amp and clip grad (#86)
* added model parallel process group for amp and clip grad
* update amp and clip with model parallel process group
* remove pipeline_prev/next group (#88)
* micro batch offload
* optimize pipeline gpu memory usage
* pipeline can receive tensor shape (#93)
* optimize pipeline gpu memory usage
* fix grad accumulation step counter
* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
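The main user-visible change in this diff is the updated PipelineSchedule constructor: the old sync_data flag is gone, and the schedule instead takes a batch_data_process_func and an optional fixed tensor_shape so that stages can skip per-microbatch tensor meta communication. A minimal usage sketch based on the new signatures shown below (the import path and the concrete shape values are assumptions for illustration only, not part of this commit):

    from colossalai.engine.schedule import PipelineSchedule, InterleavedPipelineSchedule

    def process_batch(batch_data):
        # hypothetical preprocessing executed inside load_batch
        data, label = batch_data
        return data, label

    # micro-batch shape of the tensor sent between stages (batch, seq, hidden); illustrative values
    schedule = PipelineSchedule(num_microbatches=4,
                                batch_data_process_func=process_batch,
                                tensor_shape=(8, 128, 768))

    # the interleaved variant accepts the same two new arguments
    # (it additionally requires an initialized parallel context)
    interleaved = InterleavedPipelineSchedule(num_microbatches=4,
                                              num_model_chunks=2,
                                              batch_data_process_func=process_batch,
                                              tensor_shape=(8, 128, 768))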
@@ -2,15 +2,12 @@
# -*- encoding: utf-8 -*-

import torch
from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.logging import get_dist_logger
from colossalai.utils import is_using_ddp, is_using_pp
from torch import Tensor
@@ -84,7 +81,7 @@ class Engine:

    def backward(self, loss: Tensor):
        """Start backward propagation given the loss value computed by a loss function

        :param loss: loss value computed by a loss function
        :type loss: :class:`torch.Tensor`
        """
@@ -92,7 +89,7 @@ class Engine:

    def backward_by_grad(self, tensor, grad):
        """Start backward propagation given the gradient of the output tensor

        :param tensor: output tensor
        :type tensor: :class:`torch.Tensor`
        :param grad: gradient passed back to the output
@@ -1,5 +1,7 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler

__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
@@ -0,0 +1,41 @@
#!/usr/bin/env python

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from collections import defaultdict


@GRADIENT_HANDLER.register_module
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    An all-reduce collective communication is performed by
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters of
    the same type to improve the efficiency of communication.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in sub pipeline parallel groups.
        """
        if gpc.pipeline_parallel_size > 1:
            # bucketize and all-reduce
            buckets = defaultdict(lambda: defaultdict(list))
            # Pack the buckets.
            for param in self._model.parameters():
                group = getattr(param, 'pipeline_shared_module_pg', None)
                if param.requires_grad and param.grad is not None and group is not None:
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, all-reduce and copy all-reduced grads.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
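The handler above only all-reduces parameters that carry a pipeline_shared_module_pg attribute pointing at a process group. A hedged sketch of how a shared (weight-tied) module could be tagged so the handler picks it up; the wrapper this commit adds for that purpose is not shown in this excerpt, and the sizes and ranks below are purely illustrative (an initialized default process group is assumed):

    import torch.nn as nn
    import torch.distributed as dist

    # embedding weight shared by the first and the last pipeline stage (illustrative)
    shared_embedding = nn.Embedding(50304, 768)
    shared_group = dist.new_group(ranks=[0, 3])   # the ranks that hold the shared module

    # handle_gradient() buckets and all-reduces the grads of every parameter tagged this way
    for param in shared_embedding.parameters():
        param.pipeline_shared_module_pg = shared_group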
@@ -5,8 +5,7 @@ from abc import ABC, abstractmethod

import torch

from torch import Tensor
from typing import Iterable, Union, List, Callable
from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@@ -32,18 +31,17 @@ class BaseSchedule(ABC):
        return element

    def _move_to_device(self, data):
        if isinstance(data, (tuple, list)):
            data = tuple([self._move_tensor(d) for d in data])
        elif torch.is_tensor(data):
            data = data.to(get_current_device()).detach()
        if isinstance(data, dict):
            data = {k: self._move_tensor(v) for k, v in data.items()}
        else:
            data = self._move_tensor(data)
        return data

    def _to_list(self, data):
        if torch.is_tensor(data):
            return [data]
        return data

    @staticmethod
    def _check_sanity(data, tag):
        assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'

    def load_batch(self, data_iter):
    def load_batch(self, data_iter, to_gpu=True):
        """Loads a batch from the data iterator. It returns the data and labels, which are
        already on the same GPU as the model.
@@ -58,13 +56,17 @@ class BaseSchedule(ABC):
            data, label = self.batch_data_process_func(batch_data)
        else:
            data, label = batch_data

        if isinstance(label, (tuple, list)):
            self.batch_size = label[0].size(0)
        self._check_sanity(data, 'data')
        self._check_sanity(label, 'label')
        if isinstance(data, torch.Tensor):
            self.batch_size = data.size(0)
        else:
            self.batch_size = label.size(0)
        data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
        return self._move_to_device(data), self._move_to_device(label)
            self.batch_size = next(iter(data.values())).size(0)
        data, label = split_batch(data), split_batch(label)
        if to_gpu:
            return self._move_to_device(data), self._move_to_device(label)
        return data, label
    def pre_processing(self, engine: Engine):
        """To perform actions before running the schedule.

@@ -76,7 +78,8 @@ class BaseSchedule(ABC):
                              engine: Engine,
                              data_iter: Iterable,
                              forward_only: bool,
                              return_loss: bool = True
                              return_loss: bool = True,
                              return_output_label: bool = True
                              ):
        """The process function over a batch of dataset for training or evaluation.
@@ -85,5 +88,24 @@ class BaseSchedule(ABC):
        :param labels: ground truth
        :param forward_only: If True, the process won't include backward
        :param return_loss: If False, the loss won't be returned
        :param return_output_label: If False, the output and label won't be returned
        """
        pass
        pass

    @staticmethod
    def _call_engine(engine, inputs):
        if isinstance(inputs, torch.Tensor):
            return engine(inputs)
        else:
            return engine(**inputs)

    @staticmethod
    def _call_engine_criterion(engine, outputs, labels):
        assert isinstance(outputs, (torch.Tensor, list, tuple)
                          ), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )
        if isinstance(labels, torch.Tensor):
            return engine.criterion(*outputs, labels)
        else:
            return engine.criterion(*outputs, **labels)
@@ -5,9 +5,7 @@ from typing import Iterable

import torch

import torch.nn as nn
from colossalai.engine import Engine
from torch.optim import Optimizer
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
@@ -27,18 +25,21 @@ class NonPipelineSchedule(BaseSchedule):
                              engine: Engine,
                              data_iter: Iterable,
                              forward_only: bool = False,
                              return_loss: bool = True):
                              return_loss: bool = True,
                              return_output_label: bool = True):
        """The process function that loads a batch of dataset and feeds it to the model.
        The returned labels and loss will be None if :attr:`return_loss` is False.
        :param engine: Model for training and inference
        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
        :param return_loss: Loss will be returned if True
        :param return_output_label: Output and label will be returned if True
        :type engine: Engine
        :type data_iter: Iterator
        :type forward_only: bool, optional
        :type return_loss: bool, optional
        :type return_output_label: bool, optional

        :return: (output, label, loss)
        :rtype: Tuple[:class:`torch.Tensor`]
        """
@@ -48,16 +49,20 @@ class NonPipelineSchedule(BaseSchedule):

        # forward
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = engine(*data)
            if not isinstance(output, (tuple, list)):
                output = (output,)
            output = self._call_engine(engine, data)
            if return_loss:
                loss = engine.criterion(*output, *label)
                loss = self._call_engine_criterion(engine, output, label)

        if not forward_only:
            engine.backward(loss)

        if return_loss:
            return output, label, loss
        if return_output_label:
            if return_loss:
                return output, label, loss
            else:
                return output, label, None
        else:
            return output, None, None
            if return_loss:
                return None, None, loss
            else:
                return None, None, None
@@ -1,19 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Union
from typing import List, Tuple, Union, Callable
import inspect
import torch.cuda
import torch.distributed as dist
from torch import Tensor

from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils.cuda import get_current_device
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                             ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from ._base_schedule import BaseSchedule
@@ -30,102 +30,79 @@ class PipelineSchedule(BaseSchedule):
    :class:`NonPipelineSchedule`.

    :param num_microbatches: The number of microbatches
    :param amp_type: The type of automatic mixed precision
    :param amp_config: The configuration of automatic mixed precision
    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
    :type num_microbatches: int
    :type amp_type: AMP_TYPE
    :type amp_config: dict
    :type sync_data: bool
    :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
    :type batch_data_process_func: Callable
    """

    def __init__(self,
                 num_microbatches,
                 sync_data: bool = True):
        super().__init__()
                 batch_data_process_func: Callable = None,
                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
        super().__init__(batch_data_process_func=batch_data_process_func)
        self.num_microbatches = num_microbatches
        self.sync_data = sync_data
        self.dtype = torch.float
        self.tensor_shape = tensor_shape
    def _move_to_device(self, data):
        if isinstance(data, (
                tuple,
                list,
        )):
            assert len(data) == 1, "Data tuple's length in pipeline should be 1"
            data = data[0]
        assert torch.is_tensor(data), "Data in pipeline should be tensor"
        data = data.to(get_current_device()).detach()
        return data

    def _sync_data(self):
        reqs = []
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_global_rank()
            reqs.append(dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
                async_op=True
            ))
            reqs.append(dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
                async_op=True
            ))
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
            reqs.append(dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
                async_op=True
            ))
            reqs.append(dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
                async_op=True
            ))
        for req in reqs:
            req.wait()
    # Pipeline schedule just puts data in memory
    def load_batch(self, data_iter):
        if data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        self.batch_pos = 0
        data, label = next(data_iter)
        self.batch_data, self.batch_label = \
            self._move_to_device(data), self._move_to_device(label)
        batch_size = self.batch_data.shape[0]
        assert batch_size % self.num_microbatches == 0, \
        # Pipeline schedule just puts data in memory
        self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
        self.microbatch_offset = 0
        assert self.batch_size % self.num_microbatches == 0, \
            "Batch size should be divisible by the number of microbatches"
        self.microbatch_size = batch_size // self.num_microbatches
        if self.sync_data:
            self._sync_data()
        self.microbatch_size = self.batch_size // self.num_microbatches

    def _get_data_slice(self, tensor):
        return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
    def _get_data_slice(self, data, offset):
        if isinstance(data, torch.Tensor):
            return data[offset: offset + self.microbatch_size]
        else:
            return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}

    def load_micro_batch(self):
        data = self._get_data_slice(self.batch_data)
        label = self._get_data_slice(self.batch_label)
        self.batch_pos += self.microbatch_size
        return (data,), (label,)
        data = self._get_data_slice(self.batch_data, self.microbatch_offset)
        label = self._get_data_slice(self.batch_label, self.microbatch_offset)
        self.microbatch_offset += self.microbatch_size
        return self._move_to_device(data), self._move_to_device(label)
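A worked example of the new offset-based micro-batch slicing (the tensor names and sizes below are illustrative, not part of this commit): with a global batch of 16 samples and num_microbatches = 4, microbatch_size = 16 // 4 = 4, successive load_micro_batch() calls return rows [0:4], [4:8], [8:12], [12:16], and _get_data_slice applies the same slice to every value of a dict batch:

    import torch

    batch_data = {'input_ids': torch.zeros(16, 128, dtype=torch.long),
                  'attention_mask': torch.ones(16, 128)}
    num_microbatches, batch_size = 4, 16
    microbatch_size = batch_size // num_microbatches     # 4

    offset = 4                                           # second micro-batch
    micro = {k: v[offset:offset + microbatch_size] for k, v in batch_data.items()}
    assert micro['input_ids'].shape[0] == microbatch_size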

    def pre_processing(self, engine):
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        if isinstance(engine.model, NaiveAMPModel):
        model = engine.model
        if isinstance(model, NaiveAMPModel):
            self.dtype = torch.half
            model = model.model
        sig = inspect.signature(model.forward)
        for p in sig.parameters.values():
            assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
    @staticmethod
    def _call_engine(model, input_tensor, batch_data):
        if isinstance(model, NaiveAMPModel):
            sig = inspect.signature(model.model.forward)
        else:
            sig = inspect.signature(model.forward)
        if isinstance(batch_data, torch.Tensor):
            if input_tensor is None:
                return model(batch_data)
            elif len(sig.parameters) > 1:
                return model(input_tensor, batch_data)
            else:
                return model(input_tensor)
        else:
            filter_batch = True
            for p in sig.parameters.values():
                if p.kind == inspect.Parameter.VAR_KEYWORD:
                    filter_batch = False
            if filter_batch:
                batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
            if input_tensor is None:
                return model(**batch_data)
            else:
                return model(input_tensor, **batch_data)
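A toy illustration of the signature-based filtering performed above (not part of the diff): a dict batch is pruned to the keyword arguments that the stage's forward actually declares, unless that forward accepts **kwargs; the forward below is hypothetical:

    import inspect

    def forward(hidden_states=None, input_ids=None):     # hypothetical stage forward
        return hidden_states if hidden_states is not None else input_ids

    sig = inspect.signature(forward)
    batch_data = {'input_ids': 1, 'labels': 2}            # 'labels' is not a forward argument
    filtered = {k: v for k, v in batch_data.items() if k in sig.parameters}
    assert filtered == {'input_ids': 1}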

    def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.
@@ -140,26 +117,19 @@ class PipelineSchedule(BaseSchedule):
        :return: output or the loss value of the current pipeline stage
        :rtype: :class:`torch.Tensor`
        """
        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = engine(input_tensor)
            data, label = self.load_micro_batch()
        output_tensor = self._call_engine(engine.model, input_tensor, data)
        output_tensor = squeeze(output_tensor)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = engine.criterion(output_tensor, *label) \
                    / self.num_microbatches

                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
            if return_output_label:
                return_tensors.append(tuple((output_tensor, label)))
            if accum_loss is not None:
                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                accum_loss.add_(loss_reduced.detach())
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor

        else:
            return output_tensor
@@ -203,7 +173,8 @@ class PipelineSchedule(BaseSchedule):
                              engine,
                              data_iter,
                              forward_only=False,
                              return_loss=True):
                              return_loss=True,
                              return_output_label=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if this is the last stage, an empty tuple otherwise.

@@ -215,6 +186,8 @@ class PipelineSchedule(BaseSchedule):
        :type forward_only: bool
        :param return_loss: whether to return the loss value. Default is True.
        :type return_loss: bool
        :param return_output_label: If False, the output and label won't be returned
        :type return_output_label: bool

        :return: (output, label, loss)
        :rtype: Tuple[:class:`torch.Tensor`]
@@ -238,11 +211,14 @@ class PipelineSchedule(BaseSchedule):
        input_tensors = []
        output_tensors = []
        return_tensors = []
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None
        # Used for tensor meta information communication
        ft_shape = None
        ft_shape = self.tensor_shape
        bt_shape = None
        fs_checker = True
        fs_checker = self.tensor_shape is None

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
@@ -251,7 +227,8 @@ class PipelineSchedule(BaseSchedule):
            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
                return_output_label=return_output_label,
                accum_loss=accum_loss
            )
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                bt_shape = output_tensor.shape
@@ -276,7 +253,8 @@ class PipelineSchedule(BaseSchedule):

            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
                return_output_label=return_output_label,
                accum_loss=accum_loss
            )
            if forward_only:
                send_forward(output_tensor)
@@ -327,24 +305,37 @@ class PipelineSchedule(BaseSchedule):
                send_backward(input_tensor_grad)

        if len(return_tensors) > 0:
            if return_loss:
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
            output, label = tuple(map(list, zip(*return_tensors)))
            return (torch.cat(output, dim=0),
                    torch.cat(label, dim=0),
                    accum_loss)
        else:
            return tuple((None, None, None))
            return tuple((None, None, accum_loss))
class InterleavedPipelineSchedule(PipelineSchedule):
    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
    def __init__(self,
                 num_microbatches,
                 num_model_chunks,
                 batch_data_process_func: Callable = None,
                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
        """A helper schedule class for pipeline parallelism running environment.
        It uses interleaved 1F1B strategy. Other properties are similar to
        :class:`NonPipelineSchedule`.

        :param num_microbatches: The number of microbatches
        :type num_microbatches: int
        :param num_model_chunks: The number of model chunks
        :type num_model_chunks: int
        :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
        :type batch_data_process_func: Callable
        """
        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
            'num_microbatches must be an integer multiple of pipeline parallel world size'
        super().__init__(num_microbatches, sync_data=sync_data)
        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)
        self.num_model_chunks = num_model_chunks

    def pre_processing(self, engine):
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
@@ -355,32 +346,46 @@ class InterleavedPipelineSchedule(PipelineSchedule):
        if isinstance(engine.model[0], NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
        for model in engine.model:
            if isinstance(model, NaiveAMPModel):
                model = model.model
            sig = inspect.signature(model.forward)
            for p in sig.parameters.values():
                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
    def load_batch(self, data_iter):
        super().load_batch(data_iter)
        # overwrite microbatch_offset, since model chunks load the same microbatch and should track the offset
        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]

    def load_micro_batch(self, model_chunk_id):
        data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
        label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
        self.microbatch_offset[model_chunk_id] += self.microbatch_size
        return self._move_to_device(data), self._move_to_device(label)
    def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.
        """

        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = model(input_tensor)
            data, label = self.load_micro_batch(model_chunk_id)
        output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
        output_tensor = squeeze(output_tensor)

        if gpc.is_pipeline_last_stage():
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
            if return_output_label:
                return_tensors.append(tuple((output_tensor, label)))
            if accum_loss is not None:
                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                accum_loss.add_(loss_reduced.detach())
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor

        else:
            return output_tensor
    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        """Run interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

@@ -394,11 +399,15 @@ class InterleavedPipelineSchedule(PipelineSchedule):
        return_tensors = []
        if not forward_only:
            output_tensor_grads = [[] for _ in range(len(model))]
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None

        # Used for tensor meta information communication
        input_tensor_shapes = [None for _ in range(len(model))]
        input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
        output_tensor_shapes = [None for _ in range(len(model))]
        send_tensor_shape_flags = [True for _ in range(len(model))]
        send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -450,8 +459,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id][-1]
            output_tensor = self.forward_step(
                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
            output_tensor = self.forward_step(engine, model_chunk_id, input_tensor,
                                              return_tensors, return_output_label=return_output_label, accum_loss=accum_loss)
            output_tensors[model_chunk_id].append(output_tensor)

            # if forward-only, no need to save tensors for a backward pass
@@ -633,12 +642,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                                                  dtype=self.dtype))

        if len(return_tensors) > 0:
            if return_loss:
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
            output, label = tuple(map(list, zip(*return_tensors)))
            return (torch.cat(output, dim=0),
                    torch.cat(label, dim=0),
                    accum_loss)
        else:
            return tuple((None, None, None))
            return tuple((None, None, accum_loss))