Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 11:44:03 +00:00)
add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)
@@ -20,10 +20,16 @@ def convert_to_naive_amp(model: nn.Module,
     :return: (model, optimizer)
     :rtype: Tuple
     """
-    if is_no_pp_or_last_stage():
-        model = NaiveAMPModel(model, output_to_fp32=True)
+    if isinstance(model, nn.ModuleList):
+        # interleaved pipeline
+        module_list = []
+        for chunk, m in enumerate(model):
+            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
+            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
+        model = nn.ModuleList(module_list)
     else:
-        model = NaiveAMPModel(model, output_to_fp32=False)
+        output_to_fp32 = is_no_pp_or_last_stage()
+        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

     optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
     return model, optimizer
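The hunk above changes how convert_to_naive_amp wraps the model: when the model arrives as an nn.ModuleList (one sub-module per interleaved pipeline chunk), each chunk is wrapped separately and only the last chunk of the last pipeline stage casts its output back to fp32. The sketch below illustrates that control flow only; ToyChunk, DummyAMPWrapper and last_stage are stand-ins for a model chunk, ColossalAI's NaiveAMPModel and is_no_pp_or_last_stage(), not the library's actual API.

import torch
from torch import nn


class ToyChunk(nn.Module):
    # A tiny stand-in for one pipeline model chunk (elementwise so it runs in fp16 on CPU).
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.weight


class DummyAMPWrapper(nn.Module):
    # Stand-in for NaiveAMPModel: run the wrapped module in fp16 and
    # optionally cast the output back to fp32.
    def __init__(self, module: nn.Module, output_to_fp32: bool):
        super().__init__()
        self.module = module.half()
        self.output_to_fp32 = output_to_fp32

    def forward(self, x):
        out = self.module(x.half())
        return out.float() if self.output_to_fp32 else out


def wrap_for_naive_amp(model, last_stage: bool):
    if isinstance(model, nn.ModuleList):
        # Interleaved pipeline: wrap each chunk separately; only the final chunk
        # on the last stage casts its output back to fp32.
        return nn.ModuleList(
            DummyAMPWrapper(m, output_to_fp32=last_stage and i == len(model) - 1)
            for i, m in enumerate(model)
        )
    # Non-interleaved case: one wrapper per stage.
    return DummyAMPWrapper(model, output_to_fp32=last_stage)


chunks = nn.ModuleList(ToyChunk() for _ in range(4))
wrapped = wrap_for_naive_amp(chunks, last_stage=True)
print(wrapped[0](torch.randn(2, 8)).dtype)   # torch.float16
print(wrapped[-1](torch.randn(2, 8)).dtype)  # torch.float32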
@@ -14,7 +14,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier, is_using_pp)


 def _zero_grad_group_helper(group, set_to_none):
@@ -58,7 +58,8 @@ class DynamicGradScaler:
                  backoff_factor,
                  growth_interval,
                  hysteresis,
-                 max_scale: int = None):
+                 max_scale: int = None,
+                 verbose: bool = False):
         """"Grad scaler with dynamic scale that gets adjusted
         during training."""
         assert initial_scale > 0.0
@@ -91,6 +92,7 @@ class DynamicGradScaler:
         self._hysteresis_tracker = self.hysteresis

         self._logger = get_dist_logger()
+        self.verbose = verbose

     @property
     def scale(self):
@@ -111,7 +113,8 @@ class DynamicGradScaler:
             if self._hysteresis_tracker <= 0:
                 self._scale = torch.max(self._scale * self.backoff_factor,
                                         self.min_scale)
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
+                if self.verbose:
+                    self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
             # If there is no nan/inf, increment the growth tracker.
             self._growth_tracker += 1
@@ -122,11 +125,14 @@
                 self._hysteresis_tracker = self.hysteresis
                 # and scale up the loss scale.
                 if self._max_scale is not None and self._scale >= self._max_scale:
-                    self._logger.info(
-                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
+                    if self.verbose:
+                        self._logger.info(
+                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
                 else:
                     self._scale = self._scale * self.growth_factor
-                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                    if self.verbose:
+                        self._logger.info(
+                            f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])

     def state_dict(self):
         state_dict = {}
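The DynamicGradScaler hunks above add a verbose flag and put every log call behind it, so the scaler is silent by default while the scaling rule itself is unchanged: back off on overflow once the hysteresis budget is spent, and grow the scale after growth_interval clean steps unless max_scale has been reached. The following is a simplified, self-contained sketch of that rule with verbose-gated prints; it uses plain floats and print instead of the distributed logger, so it is an illustration rather than the class in this diff.

class ToyDynamicGradScaler:
    # Simplified dynamic loss scaler mirroring the update rule shown in the diff.
    def __init__(self, initial_scale=2.0 ** 32, min_scale=1.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2,
                 max_scale=2.0 ** 32, verbose=False):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.max_scale = max_scale
        self.verbose = verbose
        self._growth_tracker = 0
        self._hysteresis_tracker = hysteresis

    def update(self, found_inf: bool):
        if found_inf:
            # Overflow: reset growth progress and spend one hysteresis credit.
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            if self._hysteresis_tracker <= 0:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                if self.verbose:  # logging only when explicitly requested
                    print(f"overflow occurs, loss scale is adjusted to {self.scale}")
        else:
            # Clean step: after growth_interval of them, try to grow the scale.
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                if self.max_scale is not None and self.scale >= self.max_scale:
                    if self.verbose:
                        print(f"loss scale {self.scale} has reached the max scale allowed")
                else:
                    self.scale = self.scale * self.growth_factor
                    if self.verbose:
                        print(f"no consecutive overflow, loss scale is adjusted to {self.scale}")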
@@ -162,6 +168,8 @@ class FP16Optimizer(Optimizer):
     :type hysterisis: int
     :param max_scale: maximum loss scale allowed
     :type max_scale: int
+    :param verbose: if set to `True`, will print debug info
+    :type verbose: bool
     """

     def __init__(self,
@@ -174,27 +182,29 @@ class FP16Optimizer(Optimizer):
                  backoff_factor=0.5,
                  growth_interval=1000,
                  hysteresis=2,
-                 max_scale: int = 2 ** 32):
+                 max_scale: int = 2 ** 32,
+                 verbose: bool = False):
         # default args for compatibility
         bf16 = False
-        params_have_main_grad = True
+        params_have_main_grad = False

         # have a defaults for compatibility with pytorch optim
         self.defaults = optimizer.defaults

         # log config
         self._logger = get_dist_logger()
-        self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                          f"Optimizer: {optimizer.__class__.__name__}\n"
-                          f"clip_grad = {clip_grad}\n"
-                          f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                          f"initial_scale = {initial_scale}\n"
-                          f"min_scale = {min_scale}\n"
-                          f"growth_factor = {growth_factor}\n"
-                          f"backoff_factor = {backoff_factor}\n"
-                          f"growth_interval = {growth_interval}\n"
-                          f"hysteresis = {hysteresis}\n"
-                          f"==========================================", ranks=[0])
+        if verbose:
+            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
+                              f"Optimizer: {optimizer.__class__.__name__}\n"
+                              f"clip_grad = {clip_grad}\n"
+                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
+                              f"initial_scale = {initial_scale}\n"
+                              f"min_scale = {min_scale}\n"
+                              f"growth_factor = {growth_factor}\n"
+                              f"backoff_factor = {backoff_factor}\n"
+                              f"growth_interval = {growth_interval}\n"
+                              f"hysteresis = {hysteresis}\n"
+                              f"==========================================", ranks=[0])

         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
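In the constructor above, the configuration banner is now printed only when verbose=True, and params_have_main_grad defaults to False. As a rough usage note: the first hunk passes **amp_config into NaiveAMPOptimizer, so, assuming those keyword arguments reach this FP16 optimizer (an assumption based on that call, not something this diff states), a configuration enabling the new debug output might look like the hypothetical sketch below.

# Hypothetical amp_config sketch -- key names are taken from the parameters shown
# in this diff; which keys NaiveAMPOptimizer actually accepts is not shown here.
amp_config = dict(
    initial_scale=2 ** 32,   # example values, not documented defaults
    min_scale=1,
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=1000,
    hysteresis=2,
    max_scale=2 ** 32,
    verbose=True,            # new flag from this commit: print scaler/optimizer config
)

# model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)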
@@ -212,7 +222,8 @@ class FP16Optimizer(Optimizer):
                 backoff_factor=backoff_factor,
                 growth_interval=growth_interval,
                 hysteresis=hysteresis,
-                max_scale=max_scale
+                max_scale=max_scale,
+                verbose=verbose
             )

         # None grad scaler is only supported for bf16.
@@ -350,6 +361,11 @@ class FP16Optimizer(Optimizer):
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=gpc.get_group(ParallelMode.TENSOR))

+        if is_using_pp():
+            torch.distributed.all_reduce(self.found_inf,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=gpc.get_group(ParallelMode.PIPELINE))
+
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
         return found_inf_flag
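The final hunk makes the overflow check pipeline-aware: found_inf is already MAX-reduced across the tensor-parallel group, and with is_using_pp() it is now also MAX-reduced across the pipeline-parallel group, so every stage agrees on whether the step must be skipped. Below is a minimal, stand-alone sketch of that MAX-reduction idea; it spins up a single-process gloo group purely so the snippet runs, whereas the real code reduces over the TENSOR and PIPELINE groups obtained from gpc.

import torch
import torch.distributed as dist

# One-process group just to make the sketch executable; in the real setup the
# reduction runs over the tensor-parallel and pipeline-parallel process groups.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29512",
                        rank=0, world_size=1)

found_inf = torch.tensor([1.0])  # 1.0 means this rank saw inf/nan gradients

# MAX across ranks: if any rank overflowed, every rank observes found_inf > 0
# and all of them skip the optimizer step together.
dist.all_reduce(found_inf, op=dist.ReduceOp.MAX)

found_inf_flag = found_inf.item() > 0
print(found_inf_flag)  # True

dist.destroy_process_group()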