Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
Frank Lee
2021-11-18 19:45:06 +08:00
committed by GitHub
parent 2b05de4c64
commit 3defa32aee
80 changed files with 2194 additions and 1584 deletions

View File

@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
import torch
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
It is mainly composed of forward_backward_step for the gradient backward pass
and optimizer_step for the parameter update.
To make FP16 convenient to enable, we aggregate all code that controls
FP16 in the schedule class.
"""
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
@property
@abstractmethod
def num_steps(self):
"""The number of batches in training or evaluation.
"""
pass
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self):
"""Loads a batch of dataset. It returns the data and labels which are
already on the same GPU as the model.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
@staticmethod
def _move_tensor(element):
if torch.is_tensor(element):
if not element.is_cuda:
return element.to(get_current_device()).detach()
return element
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
if isinstance(data, (tuple, list)):
data = tuple([self._move_tensor(d) for d in data])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
def load_batch(self, data_iter):
"""Loads a batch from data iterator. It returns the data and labels which are
already on the same GPU as the model.
:param dataloader: Dataloader to be used
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(data_iter)
return self._move_to_device(data), self._move_to_device(label)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
"""
if not forward_only:
self.check_initialized()
self.optimizer.zero_grad()
def initialize(self, model, optimizer):
"""Initializes the model and the optimizer before training.
This is often used in FP16 training.
def get_lr(self):
"""Returns the current learning rate.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return model, optimizer
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
:param forward_only: If True, the process won't include backward.
:param return_loss: If False, the loss won't be returned.
:param data_iter: Data iterator of the dataset
:param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
pass
@abstractmethod
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""Updates the parameters with the optimizer.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
"""
pass

View File

@@ -4,19 +4,24 @@
try:
import apex.amp as apex_amp
except:
print('apex is required for mixed precision training')
pass
try:
import torch.cuda.amp as torch_amp
except:
print('PyTorch amp is not supported with the current PyTorch version')
pass
from typing import Iterable
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler
class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
@@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
self.fp16 = False
self.amp_type = None
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
def initialize(self, model: nn.Module, optimizer: Optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize(
self.model, self.optimizer, **self.amp_cfg)
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
def forward_backward_step(self, forward_only=False, return_loss=True):
return model, optimizer
def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will be None if :attr:`return_loss` is False.
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch()
data, label = self.load_batch(data_iter)
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = self.model(*data)
output = model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss = criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = self.model(*data)
output = model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss = criterion(*output, *label)
loss /= grad_accum_size
if not forward_only:
# backward
if self.use_zero_level_2_3:
self.optimizer.backward(loss)
optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss,
self.optimizer) as scaled_loss:
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss)
loss = optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale)
loss.div_(optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss
return output, label, loss * grad_accum_size
else:
return output, None, None
def step(self):
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer)
if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
optimizer.step()

View File

@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
from ..amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
)
# Pipeline schedule just puts data in memory
def load_batch(self):
self.check_initialized()
if self.data_iter is None:
def load_batch(self, data_iter):
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
self.batch_pos = 0
data, label = next(self.data_iter)
data, label = next(data_iter)
self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0]
@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size
return (data,), (label,)
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
def initialize(self, model, optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True):
def forward_step(self, model, criterion, input_tensor, return_tensors,
grad_accum_size, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor)
output_tensor = model(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
loss_reduced = criterion(output_tensor, *label) \
/ (self.num_microbatches * grad_accum_size)
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
else:
return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor)
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
@@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad
def forward_backward_step(self, forward_only=True, return_loss=True):
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch()
self.load_batch(data_iter)
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if forward_only:
send_forward(output_tensor)
@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
input_tensor_grad = self.backward_step(
optimizer,
input_tensor, output_tensor,
output_tensor_grad
)
if last_iteration:
input_tensor = None
@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
input_tensor_grad = self.backward_step(
optimizer,
input_tensor, output_tensor,
output_tensor_grad
)
send_backward(input_tensor_grad)
@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
sum(loss) * grad_accum_size)
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
    """Updates the parameters with the optimizer.

    This implementation only calls ``optimizer.step()``; ``model`` and
    ``grad_clipping`` are accepted for interface compatibility but are
    not used here.

    :param model: The neural network model (unused in this implementation)
    :param optimizer: Optimizer for updating the parameters
    :param grad_clipping: The norm of gradient clipping (unused in this implementation)
    :type grad_clipping: float, optional
    """
    # No gradient clipping is applied in this method; the optimizer is
    # stepped as-is.
    optimizer.step()

View File

@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
    """Casts a tensor, or every element of a list/tuple, to 32-bit float.

    :param data: A tensor, or a list/tuple whose elements provide ``.float()``
    :return: The result of ``data.float()`` for a tensor input; for a
        list/tuple input, a new list of the converted elements (a tuple
        input also yields a list)
    :raises TypeError: If ``data`` is neither a Tensor nor a list/tuple
    """
    # Single tensor: convert and return directly.
    if isinstance(data, Tensor):
        return data.float()

    # Sequence input: convert element by element, always collecting into a list.
    if isinstance(data, (list, tuple)):
        converted = []
        for item in data:
            converted.append(item.float())
        return converted

    raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")