Mirror of https://github.com/hpcaitech/ColossalAI.git
Migrated project
colossalai/engine/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *


__all__ = ['Engine']
colossalai/engine/_base_engine.py (new file, 170 lines)
@@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule


class Engine:
    """Basic engine class for training and evaluation. It runs a specific process method
    :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.

    :param train_dataloader: Dataloader in training
    :param test_dataloader: Dataloader in evaluation
    :param model: The neural network model
    :param criterion: Criterion for calculating loss
    :param optimizer: Optimizer for updating the parameters
    :param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
    :param schedule: Running schedule in :meth:`step`
    :type train_dataloader: DataLoader, optional
    :type test_dataloader: DataLoader, optional
    :type model: Module
    :type criterion: _Loss, optional
    :type optimizer: Optimizer, optional
    :type lr_scheduler: _LRScheduler, optional
    :type schedule: BaseSchedule, optional
    """
    def __init__(self,
                 train_dataloader: Optional[DataLoader] = None,
                 test_dataloader: Optional[DataLoader] = None,
                 model: Module = None,
                 criterion: _Loss = None,
                 optimizer: Optimizer = None,
                 lr_scheduler: Optional[_LRScheduler] = None,
                 schedule: BaseSchedule = None):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        assert model is not None, "Engine requires a model"
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.schedule = schedule if schedule is not None \
            else NoPipelineSchedule()
        self._logger = get_global_dist_logger()

        # build gradient handler
        self._gradient_handlers = []
        gradient_handler_cfg = []

        if hasattr(gpc.config, 'gradient_handler'):
            assert isinstance(gpc.config.gradient_handler, list), \
                f'argument gradient_handler_cfg expected type list, ' \
                f'but got type {type(gpc.config.gradient_handler)}'
            gradient_handler_cfg = gpc.config.gradient_handler
        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                         ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with ZeRO is detected, ZeROGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])
        if len(gradient_handler_cfg) == 0:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        for cfg in gradient_handler_cfg:
            handler = build_gradient_handler(cfg, self.model, self.optimizer)
            self._gradient_handlers.append(handler)

        self.schedule.initialize(self.train_dataloader, self.model,
                                 self.criterion, self.optimizer,
                                 self.lr_scheduler)
        self.forward_only = False

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def set_dataloader(self, data: DataLoader, train: bool = True):
        """Sets the dataloader for training or evaluation.

        :param data: Dataloader to be set
        :param train: Set the training dataloader if True, otherwise the evaluation dataloader
        :type data: DataLoader
        :type train: bool
        """
        if train:
            self.train_dataloader = data
        else:
            self.test_dataloader = data

    def get_model(self):
        """Returns the neural network model in the engine.
        """
        return self.model

    def get_optimizer(self):
        """Returns the optimizer in the engine.
        """
        return self.optimizer

    def get_lr_scheduler(self):
        """Returns the learning rate scheduler in the engine.
        """
        return self.lr_scheduler

    def train(self):
        """Sets the model to training mode.
        """
        self.forward_only = False
        self.schedule.train(dataloader=self.train_dataloader, mode=True)

    def eval(self):
        """Sets the model to evaluation mode.
        """
        self.forward_only = True
        self.schedule.train(dataloader=self.test_dataloader, mode=False)

    def is_train(self):
        """Returns True if it is training, otherwise False.
        """
        return not self.forward_only

    def get_lr(self):
        """Gets the current learning rate.
        """
        return self.schedule.get_lr()

    def step(self, return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation step over a batch of the dataset.

        :param return_loss: The loss will be returned if True
        :type return_loss: bool
        :return: (output, label, loss)
        """
        self.schedule.zero_grad(forward_only=self.forward_only)

        output, label, loss = self.schedule.forward_backward_step(
            forward_only=self.forward_only, return_loss=return_loss)

        if not self.forward_only:
            # all-reduce gradients across parallel groups
            self.handle_gradient()

        self.schedule.step()

        return output, label, loss
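For orientation, the snippet below is a minimal usage sketch of the Engine API added in this file; it is not part of the commit. It assumes the ColossalAI global context (gpc) and process groups have already been initialized elsewhere, and that model, criterion, optimizer and the dataloaders are placeholder objects supplied by the user.

# Usage sketch (assumption: gpc/config and process groups are already initialized
# by ColossalAI's launch utilities; model/dataloaders/optimizer are placeholders).
from colossalai.engine import Engine

engine = Engine(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                model=model,
                criterion=criterion,
                optimizer=optimizer)          # schedule defaults to NoPipelineSchedule()

engine.train()                                # set schedule/model to training mode
for _ in range(engine.schedule.num_steps):
    output, label, loss = engine.step()       # forward, backward, gradient all-reduce, optimizer step

engine.eval()                                 # evaluation: forward only
output, label, loss = engine.step(return_loss=True)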
colossalai/engine/amp_type.py (new file, 10 lines)
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from enum import Enum


class AMP_TYPE(Enum):
    APEX = 'apex'
    TORCH = 'torch'
    PARALLEL = 'parallel'
colossalai/engine/gradient_handler/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler

__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']
colossalai/engine/gradient_handler/_base_gradient_handler.py (new file, 25 lines)
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod


class BaseGradientHandler(ABC):
    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
    before optimization.

    :param model: Model where the gradients accumulate
    :param optimizer: Optimizer for updating the parameters
    :type model: Module
    :type optimizer: Optimizer
    """
    def __init__(self, model, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """A method to accumulate gradients across different parallel groups. Users should
        write their own functions or just use the functions in pre-defined subclasses.
        """
        pass
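As an illustration of the contract above, here is a sketch of a user-defined handler registered through GRADIENT_HANDLER; the class name and the per-tensor all-reduce over the data-parallel group are examples only, not part of this commit.

# Illustrative custom handler (hypothetical class name, unbucketed all-reduce).
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.engine.gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class NaiveGradientHandler(BaseGradientHandler):
    """All-reduces every gradient tensor one by one across the data parallel group."""

    def handle_gradient(self):
        world_size = gpc.get_world_size(ParallelMode.DATA)
        if world_size > 1:
            for param in self._model.parameters():
                if param.requires_grad and param.grad is not None:
                    param.grad.data.div_(world_size)
                    dist.all_reduce(param.grad.data,
                                    group=gpc.get_group(ParallelMode.DATA))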
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py (new file, 48 lines)
@@ -0,0 +1,48 @@
#!/usr/bin/env python

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode


@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    An all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        """
        # TODO: add memory buffer
        if gpc.data_parallel_size > 1:
            # bucketize and all-reduce
            buckets = {}
            # Pack the buckets.
            for param in self._model.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= gpc.get_world_size(ParallelMode.DATA)

                dist.all_reduce(
                    coalesced, group=gpc.get_group(ParallelMode.DATA))
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)
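The handler above follows the common bucket-and-flatten idiom. Below is a standalone sketch of the same pattern in plain PyTorch, independent of ColossalAI's parallel context; it assumes torch.distributed has been initialized with a default process group.

# Standalone sketch of the bucketed gradient all-reduce idiom (plain PyTorch).
# Assumption: dist.init_process_group(...) has already been called.
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_gradients(model: torch.nn.Module):
    # Group gradients by dtype so each bucket can be flattened into one buffer.
    buckets = {}
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            buckets.setdefault(param.data.type(), []).append(param)

    world_size = dist.get_world_size()
    for params in buckets.values():
        grads = [p.grad.data for p in params]
        coalesced = _flatten_dense_tensors(grads)   # single flat buffer per bucket
        coalesced /= world_size                     # pre-divide so the sum becomes an average
        dist.all_reduce(coalesced)                  # sum across the default group
        for buf, synced in zip(grads,
                               _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)                       # write averaged grads back in place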
colossalai/engine/gradient_handler/_zero_gradient_handler.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    An all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    This class is specialized with ZeRO optimization.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        """
        self._optimizer.allreduce_gradients()
colossalai/engine/schedule/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule

__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
colossalai/engine/schedule/_base_schedule.py (new file, 129 lines)
@@ -0,0 +1,129 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod

import torch

from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device


class BaseSchedule(ABC):
    """A basic helper class to control the process of training or evaluation.
    """
    def __init__(self):
        self.initialized = False
        self.logger = get_global_dist_logger()

    @property
    @abstractmethod
    def num_steps(self):
        """The number of batches in training or evaluation.
        """
        pass

    def initialize(self,
                   dataloader=None,
                   model=None,
                   criterion=None,
                   optimizer=None,
                   lr_scheduler=None):
        """Initializes the schedule and sets parameters before running.

        :param dataloader: DataLoader in training or evaluation
        :param model: The neural network model
        :param criterion: Criterion for calculating loss
        :param optimizer: Optimizer for updating the parameters
        :param lr_scheduler: Learning rate scheduler in the process
        """
        self.dataloader = dataloader
        assert model is not None, "Schedule requires a model"
        self.model = model
        assert criterion is not None, "Schedule requires a criterion"
        self.criterion = criterion
        assert optimizer is not None, "Schedule requires an optimizer"
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.initialized = True

    def check_initialized(self):
        """Checks whether the schedule is initialized.
        """
        assert self.initialized, \
            'Schedule is not initialized. Call schedule.initialize(...) before using it.'

    def load_batch(self):
        """Loads a batch of the dataset. It returns the data and labels which have
        already been moved to the same GPU as the model.

        :return: (data, label)
        :rtype: (Tensor, Tensor)
        """
        self.check_initialized()
        if self.data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        data, label = next(self.data_iter)
        return self._move_to_device(data), self._move_to_device(label)

    def _move_to_device(self, data):
        if isinstance(data, (
                tuple,
                list,
        )):
            data = tuple([
                d.to(get_current_device()).detach() for d in data
                if torch.is_tensor(d)
            ])
        elif torch.is_tensor(data):
            data = data.to(get_current_device()).detach()
        return data

    def train(self, dataloader=None, mode=True):
        """Sets the dataloader to be used and turns the model to
        training or evaluation mode.

        :param dataloader: Dataloader to be used
        :param mode: If True, the model will be set to training mode. Otherwise, evaluation mode.
        """
        self.check_initialized()
        if mode:
            self.model.train()
        else:
            self.model.eval()
        if dataloader is not None:
            self.dataloader = dataloader
            self.data_iter = iter(dataloader)

    def zero_grad(self, forward_only=False):
        """Cleans gradients with the optimizer.
        """
        if not forward_only:
            self.check_initialized()
            self.optimizer.zero_grad()

    def get_lr(self):
        """Returns the current learning rate.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_lr()[0]
        else:
            return self.optimizer.param_groups[0]['lr']

    def step(self):
        """Updates the parameters and the learning rate with the optimizer.
        """
        self.check_initialized()
        self.optimizer.step()
        # update lr scheduler
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    @abstractmethod
    def forward_backward_step(self, forward_only=False, return_loss=True):
        """The process function over a batch of the dataset for training or evaluation.

        :param forward_only: If True, the process won't include backward.
        :param return_loss: If False, the loss won't be returned.
        """
        pass
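To make the abstract contract concrete, a minimal illustrative subclass is sketched below; it mirrors the single forward/backward pattern of NoPipelineSchedule without AMP or ZeRO handling and is not part of this commit. It assumes each batch is a tuple of input tensors and a tuple of labels, as load_batch returns for tuple-style dataloaders.

# Minimal illustrative subclass of BaseSchedule (hypothetical, no AMP/ZeRO support).
from colossalai.engine.schedule import BaseSchedule


class SimpleSchedule(BaseSchedule):

    @property
    def num_steps(self):
        return len(self.dataloader)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        data, label = self.load_batch()        # tensors already moved to the current device
        output = self.model(*data)             # assumes the batch is a tuple of inputs
        if not isinstance(output, (tuple, list)):
            output = (output,)
        loss = self.criterion(*output, *label) if return_loss else None
        if not forward_only:
            loss.backward()
        return output, label, loss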
colossalai/engine/schedule/_no_pipeline.py (new file, 185 lines)
@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

try:
    import apex.amp as apex_amp
except ImportError:
    print('apex is required for mixed precision training')
try:
    import torch.cuda.amp as torch_amp
except ImportError:
    print('PyTorch amp is not supported with the current PyTorch version')

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from ._base_schedule import BaseSchedule


class NoPipelineSchedule(BaseSchedule):
    """A helper schedule class for a running environment without pipeline parallelism.
    During one process, it loads a batch of the dataset and feeds it to the model.
    After getting the output and calculating the loss, it will use :meth:`step`
    to update the parameters if it is in training mode.

    :param amp_type: The type of automatic mixed precision
    :param amp_config: The configuration of automatic mixed precision
    :type amp_type: AMP_TYPE
    :type amp_config: dict
    """
    def __init__(
            self,
            amp_type: AMP_TYPE = None,
            amp_config: dict = None,
    ):
        super().__init__()

        # mixed precision training
        assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
            'unrecognised value for argument fp16, it can only be None, torch or apex'

        # LSG: check compatibility
        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
                ParallelMode.TENSOR) > 1:
            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
                'You can only use AMP_TYPE.PARALLEL for tensor parallel training'
        self.use_zero_level_2_3 = False

        if amp_type is not None:
            self.fp16 = True
            self.amp_type = amp_type

            if amp_config is not None:
                assert isinstance(amp_config, dict), \
                    f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'

            if self.amp_type == AMP_TYPE.TORCH:
                # torch amp
                if amp_config is None:
                    amp_config = dict()
                self.amp_cfg = amp_config
            elif self.amp_type == AMP_TYPE.APEX:
                # apex amp
                if amp_config is None:
                    amp_config = dict(opt_level='O2')
                self.logger.warning(
                    'apex is deprecated, please consider using torch.cuda.amp instead.'
                )
                self.amp_cfg = amp_config
            elif self.amp_type == AMP_TYPE.PARALLEL:
                # use fp16 optimizer for tensor parallel training
                if amp_config is None:
                    amp_config = dict()
                self.amp_cfg = amp_config
        else:
            self.fp16 = False
            self.amp_type = None

    @property
    def num_steps(self):
        return len(self.dataloader)

    def initialize(self,
                   dataloader,
                   model,
                   criterion,
                   optimizer,
                   lr_scheduler=None):
        super().initialize(dataloader,
                           model,
                           criterion,
                           optimizer,
                           lr_scheduler=lr_scheduler)
        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                       ZeroRedundancyOptimizer_Level_3)):
            self.use_zero_level_2_3 = True
            assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'

        if self.fp16:
            if self.amp_type == AMP_TYPE.TORCH:
                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
            elif self.amp_type == AMP_TYPE.APEX:
                self.model, self.optimizer = apex_amp.initialize(
                    self.model, self.optimizer, **self.amp_cfg)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        """The process function that loads a batch of the dataset and feeds it to the model.
        The returned labels and loss will be None if :attr:`return_loss` is False.

        :return: (output, label, loss)
        """
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

        data, label = self.load_batch()
        loss = None

        # LSG: leave for debug, make sure dataloader is deterministic
        # if forward_only:
        #     img = data[0]
        #     rank = gpc.get_local_rank(ParallelMode.DATA)
        #     world_size = gpc.get_world_size(ParallelMode.DATA)
        #     group = gpc.get_group(ParallelMode.DATA)
        #     input_list = [img.clone() for _ in range(world_size)]
        #     output_list = [torch.empty_like(img) for _ in range(world_size)]
        #     output_list[rank] = img.clone()
        #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
        #     assert torch.equal(output_list[0], output_list[1])  # and torch.equal(output_list[1], output_list[2])

        # forward
        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
            with torch_amp.autocast():
                output = self.model(*data)
                if not isinstance(output, (tuple, list)):
                    output = (output,)
                if return_loss:
                    loss = self.criterion(*output, *label)
        else:
            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                data = convert_to_fp16(data)

            output = self.model(*data)
            if not isinstance(output, (tuple, list)):
                output = (output,)
            if return_loss:
                loss = self.criterion(*output, *label)

        if not forward_only:
            # backward
            if self.use_zero_level_2_3:
                self.optimizer.backward(loss)
            elif self.fp16:
                if self.amp_type == AMP_TYPE.APEX:
                    with apex_amp.scale_loss(loss,
                                             self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                elif self.amp_type == AMP_TYPE.TORCH:
                    self._torch_amp_scaler.scale(loss).backward()
                elif self.amp_type == AMP_TYPE.PARALLEL:
                    loss = self.optimizer.scale_loss(loss)
                    loss.backward()
                    # scale back to display the original value in logs
                    loss.div_(self.optimizer.grad_scaler.scale)
            else:
                loss.backward()

        if return_loss:
            return output, label, loss
        else:
            return output, None, None

    def step(self):
        # step optimizer
        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
            self._torch_amp_scaler.step(self.optimizer)
            self._torch_amp_scaler.update()
        else:
            self.optimizer.step()

        # update lr scheduler
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
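A short sketch of pairing this schedule with torch.cuda.amp through the Engine follows; it is illustrative only. The amp_config keys are forwarded to torch.cuda.amp.GradScaler, and the model/dataloader/optimizer objects are placeholders as in the earlier sketch.

# Sketch: mixed precision with AMP_TYPE.TORCH (placeholder training objects).
from colossalai.engine import Engine, AMP_TYPE
from colossalai.engine.schedule import NoPipelineSchedule

schedule = NoPipelineSchedule(amp_type=AMP_TYPE.TORCH,
                              amp_config=dict(init_scale=2.**16))
engine = Engine(train_dataloader=train_dataloader,
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                schedule=schedule)
engine.train()
output, label, loss = engine.step()    # autocast forward, scaled backward, scaler.step/update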
colossalai/engine/schedule/_pipeline.py (new file, 316 lines)
@@ -0,0 +1,316 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Union

import torch.cuda
import torch.distributed as dist
from torch import Tensor

from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE


def squeeze(x: Union[Tensor, tuple, list]):
    if isinstance(x, (tuple, list)):
        return x[0]
    else:
        return x


class PipelineSchedule(BaseSchedule):
    """A helper schedule class for a pipeline parallelism running environment.
    It uses the non-interleaved 1F1B strategy. Other properties are similar to
    :class:`NoPipelineSchedule`.

    :param num_microbatches: The number of microbatches
    :param amp_type: The type of automatic mixed precision
    :param amp_config: The configuration of automatic mixed precision
    :type num_microbatches: int
    :type amp_type: AMP_TYPE
    :type amp_config: dict
    """

    def __init__(self,
                 num_microbatches,
                 amp_type: AMP_TYPE = None,
                 amp_config: dict = None):
        super().__init__()

        self.num_microbatches = num_microbatches
        self.data_sync = True  # close after making sure data is identical

        # amp
        # LSG: amp_config is not used, but leave here for future extension
        self.amp_type = amp_type
        self.amp_config = amp_config

        if self.amp_type is not None:
            assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'

    def _move_to_device(self, data):
        if isinstance(data, (
                tuple,
                list,
        )):
            assert len(data) == 1, "Data tuple's length in pipeline should be 1"
            data = data[0]
        assert torch.is_tensor(data), "Data in pipeline should be tensor"
        data = data.to(get_current_device()).detach()
        return data

    def _sync_data(self):
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_global_rank()
            dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV)
            )
            dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV)
            )
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
            dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
            )
            dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
            )

    # Pipeline schedule just puts data in memory
    def load_batch(self):
        self.check_initialized()
        if self.data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        self.batch_pos = 0
        data, label = next(self.data_iter)
        self.batch_data, self.batch_label = \
            self._move_to_device(data), self._move_to_device(label)
        batch_size = self.batch_data.shape[0]
        assert batch_size % self.num_microbatches == 0, \
            "Batch size should be divisible by the number of microbatches"
        self.microbatch_size = batch_size // self.num_microbatches
        if self.data_sync:
            self._sync_data()

    def _get_data_slice(self, tensor):
        return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]

    def load_micro_batch(self):
        data = self._get_data_slice(self.batch_data)
        label = self._get_data_slice(self.batch_label)
        self.batch_pos += self.microbatch_size
        return (data,), (label,)

    @property
    def num_steps(self):
        return len(self.dataloader)

    def initialize(self,
                   dataloader,
                   model,
                   criterion,
                   optimizer,
                   lr_scheduler=None):
        super().initialize(dataloader,
                           model,
                           criterion,
                           optimizer,
                           lr_scheduler=lr_scheduler)
        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                       ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        # LSG: set default dtype to fp16 for communication
        if self.amp_type == AMP_TYPE.PARALLEL:
            torch.set_default_dtype(torch.half)
            self.logger.info(
                'default tensor dtype is set to torch.half for fp16 training',
                ranks=[0])

    def forward_step(self, input_tensor, return_tensors, return_loss=True):
        """Forward step for the passed-in model. If it is the first stage, the input tensor
        is obtained from the data iterator, otherwise the passed-in input_tensor is used.
        Returns the output tensor. This is a helper function and can be ignored by users.
        """

        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
            if self.amp_type == AMP_TYPE.PARALLEL:
                input_tensor = convert_to_fp16(input_tensor)
        input_tensor = squeeze(input_tensor)
        output_tensor = self.model(input_tensor)
        output_tensor = squeeze(output_tensor)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = self.criterion(output_tensor,
                                              *label) / self.num_microbatches
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor

        else:
            return output_tensor

    def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        """Backward step through the passed-in output tensor. If it is the last stage, the
        output_tensor_grad is None, otherwise it is the gradient with respect to the stage's output tensor.
        Returns the gradient with respect to the input tensor (None if first stage).
        This is a helper function and can be ignored by users.
        """

        # Retain the grad on the input_tensor.
        if input_tensor is not None:
            input_tensor.retain_grad()

        # Backward pass.
        if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
            output_tensor = self.optimizer.scale_loss(output_tensor)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

        # Collect the grad of the input_tensor.
        input_tensor_grad = None
        if input_tensor is not None:
            input_tensor_grad = input_tensor.grad

        return input_tensor_grad

    def forward_backward_step(self, forward_only=True, return_loss=True):
        """Runs the non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if it is the last stage, an empty tuple otherwise.

        :return: (output, label, loss)
        """

        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

        self.load_batch()
        num_warmup_microbatches = \
            (gpc.get_world_size(ParallelMode.PIPELINE) -
             gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
        num_warmup_microbatches = min(num_warmup_microbatches,
                                      self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_tensors = None
        output_tensors = None
        if not forward_only:
            input_tensors = []
            output_tensors = []
        return_tensors = []

        # Used for tensor meta information communication
        ft_shape = None
        bt_shape = None
        fs_checker = True

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape)
            output_tensor = self.forward_step(input_tensor,
                                              return_tensors,
                                              return_loss=return_loss)
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                bt_shape = output_tensor.shape
                fs_checker = send_tensor_meta(output_tensor, fs_checker)
            send_forward(output_tensor)

            if not forward_only:
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

        # Before running 1F1B, need to receive the first forward tensor.
        # If all microbatches are run in the warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape)

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_tensor = self.forward_step(input_tensor,
                                              return_tensors,
                                              return_loss=return_loss)
            if forward_only:
                send_forward(output_tensor)

                if not last_iteration:
                    input_tensor = recv_forward(ft_shape)

            else:
                output_tensor_grad = send_forward_recv_backward(
                    output_tensor, bt_shape)

                # Add input_tensor and output_tensor to the end of the list.
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

                # Pop input_tensor and output_tensor from the start of the list for
                # the backward pass.
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                input_tensor_grad = self.backward_step(input_tensor,
                                                       output_tensor,
                                                       output_tensor_grad)

                if last_iteration:
                    input_tensor = None
                    send_backward(input_tensor_grad)
                else:
                    input_tensor = send_backward_recv_forward(
                        input_tensor_grad, ft_shape)

        # Run cooldown backward passes.
        if not forward_only:
            for i in range(num_warmup_microbatches):
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                output_tensor_grad = recv_backward(bt_shape)

                input_tensor_grad = self.backward_step(input_tensor,
                                                       output_tensor,
                                                       output_tensor_grad)

                send_backward(input_tensor_grad)

        if len(return_tensors) > 0:
            if return_loss:
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            return tuple((None, None, None))
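As a small worked example of the 1F1B bookkeeping above: each stage runs world_size - rank - 1 warmup forward passes (capped at the number of microbatches), so later stages start their backward passes sooner. A self-contained sketch of this arithmetic:

# Sketch of the warmup-depth arithmetic used by the non-interleaved 1F1B schedule.
def warmup_microbatches(pipeline_world_size: int,
                        pipeline_rank: int,
                        num_microbatches: int) -> int:
    return min(pipeline_world_size - pipeline_rank - 1, num_microbatches)

# With 4 stages and 8 microbatches: the first stage warms up with 3 forwards,
# the last stage with 0, then every stage alternates one forward with one backward.
assert warmup_microbatches(4, 0, 8) == 3
assert warmup_microbatches(4, 3, 8) == 0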
colossalai/engine/schedule/_utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Union, List

from torch import Tensor


def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
    if isinstance(data, Tensor):
        ret = data.half()
    elif isinstance(data, (list, tuple)):
        ret = [val.half() for val in data]
    else:
        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
    return ret