Migrated project

This commit is contained in:
zbian
2021-10-28 18:21:23 +02:00
parent 2ebaefc542
commit 404ecbdcc6
409 changed files with 35853 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
__all__ = ['Engine']

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .schedule import BaseSchedule, NoPipelineSchedule
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
:param schedule: Running schedule in :meth:`step`
:type train_dataloader: DataLoader, optional
:type test_dataloader: DataLoader, optional
:type model: Module
:type criterion: _Loss, optional
:type optimizer: Optimizer, optional
:type lr_scheduler: _LRScheduler, optional
:type schedule: BaseSchedule, optional
"""
def __init__(self,
train_dataloader: Optional[DataLoader] = None,
test_dataloader: Optional[DataLoader] = None,
model: Module = None,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
self._logger = get_global_dist_logger()
# build gradient handler
self._gradient_handlers = []
gradient_handler_cfg = []
if hasattr(gpc.config, 'gradient_handler'):
assert isinstance(gpc.config.gradient_handler, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}'
gradient_handler_cfg = gpc.config.gradient_handler
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
if len(gradient_handler_cfg) == 0:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
for cfg in gradient_handler_cfg:
handler = build_gradient_handler(cfg, self.model, self.optimizer)
self._gradient_handlers.append(handler)
self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
handler.handle_gradient()
def set_dataloader(self, data: DataLoader, train: bool = True):
"""Sets dataloader in training or evaluation.
:param data: Dataloader to be set
:param train: Set training dataloader if True, otherwise evaluation dataloader
:type data: DataLoader
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data
def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model
def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer
def get_lr_scheduler(self):
"""Returns the learning rate scheduler in the engine.
"""
return self.lr_scheduler
def train(self):
"""Sets the model to training mode.
"""
self.forward_only = False
self.schedule.train(dataloader=self.train_dataloader, mode=True)
def eval(self):
"""Sets the model to evaluation mode.
"""
self.forward_only = True
self.schedule.train(dataloader=self.test_dataloader, mode=False)
def is_train(self):
"""Returns True if it is in training, otherwise False.
"""
return not self.forward_only
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
def step(self, return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.
:param return_loss: loss will be returned if True
:type return_loss: bool
:return: (output, lablel, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)
output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)
if not self.forward_only:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
return output, label, loss

View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
class AMP_TYPE(Enum):
APEX = 'apex'
TORCH = 'torch'
PARALLEL = 'parallel'

View File

@@ -0,0 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseGradientHandler(ABC):
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
before optimization.
:param model: Model where the gradients accumulate
:param optimizer: Optimizer for updating the parameters
:type model: Module
:type optimizer: Optimizer
"""
def __init__(self, model, optimizer):
self._model = model
self._optimizer = optimizer
@abstractmethod
def handle_gradient(self):
"""A method to accumulate gradients across different parallel groups. Users should
write their own functions or just use the functions in pre-defined subclasses.
"""
pass

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
# TODO: add memory buffer
if gpc.data_parallel_size > 1:
# bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self._model.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= gpc.get_world_size(ParallelMode.DATA)
dist.all_reduce(
coalesced, group=gpc.get_group(ParallelMode.DATA))
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)

View File

@@ -0,0 +1,16 @@
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
This class is specialized with ZeRO optimization.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
self._optimizer.allreduce_gradients()

View File

@@ -0,0 +1,5 @@
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule
__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
"""
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
@property
@abstractmethod
def num_steps(self):
"""The number of batches in training or evaluation.
"""
pass
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self):
"""Loads a batch of dataset. It returns the data and labels which are
already in the same GPU as where the model's.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
:param dataloader: Dataloader to be used
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
"""
if not forward_only:
self.check_initialized()
self.optimizer.zero_grad()
def get_lr(self):
"""Returns the current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
:param forward_only: If True, the process won't include backward.
:param return_loss: If False, the loss won't be returned.
"""
pass

View File

@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
try:
import apex.amp as apex_amp
except:
print('apex is required for mixed precision training')
try:
import torch.cuda.amp as torch_amp
except:
print('PyTorch amp is not supported with the current PyTorch version')
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from ._base_schedule import BaseSchedule
class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()
# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type
if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'
if self.amp_type == AMP_TYPE.TORCH:
# torch apex
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize(
self.model, self.optimizer, **self.amp_cfg)
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch()
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
if not forward_only:
# backward
if self.use_zero_level_2_3:
self.optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss,
self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss
else:
return output, None, None
def step(self):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer)
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()

View File

@@ -0,0 +1,316 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union
import torch.cuda
import torch.distributed as dist
from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
if isinstance(x, (tuple, list)):
return x[0]
else:
return x
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses non-interleaved 1F1B strategy. Other properties are similar as
:class:`NoPipelineSchedule`.
:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(self,
num_microbatches,
amp_type: AMP_TYPE = None,
amp_config: dict = None):
super().__init__()
self.num_microbatches = num_microbatches
self.data_sync = True # close after making sure data is identical
# amp
# LSGL: amp_config is not used, but leave here for future extension
self.amp_type = amp_type
self.amp_config = amp_config
if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
assert len(data) == 1, "Data tuple's length in pipeline should be 1"
data = data[0]
assert torch.is_tensor(data), "Data in pipeline should be tensor"
data = data.to(get_current_device()).detach()
return data
def _sync_data(self):
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
# Pipeline schedule just puts data in memory
def load_batch(self):
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
self.batch_pos = 0
data, label = next(self.data_iter)
self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0]
assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync:
self._sync_data()
def _get_data_slice(self, tensor):
return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
def load_micro_batch(self):
data = self._get_data_slice(self.batch_data)
label = self._get_data_slice(self.batch_label)
self.batch_pos += self.microbatch_size
return (data,), (label,)
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
# LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL:
torch.set_default_dtype(torch.half)
self.logger.info(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
"""
if input_tensor is None:
input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
else:
return_tensors.append(output_tensor)
return output_tensor
else:
return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
"""
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_backward_step(self, forward_only=True, return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch()
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = min(num_warmup_microbatches,
self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
return_tensors = []
# Used for tensor meta information communication
ft_shape = None
bt_shape = None
fs_checker = True
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
send_forward(output_tensor)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if forward_only:
send_forward(output_tensor)
if not last_iteration:
input_tensor = recv_forward(ft_shape)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, bt_shape)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, ft_shape)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad)
if len(return_tensors) > 0:
if return_loss:
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from torch import Tensor
def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret