mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-22 15:26:57 +00:00 
			
		
		
		
	* Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> * Split conv2d, class token, positional embedding in 2d, Fix random number in ddp Fix convergence in cifar10, Imagenet1000 * Integrate 1d tensor parallel in Colossal-AI (#39) * fixed 1D and 2D convergence (#38) * optimized 2D operations * fixed 1D ViT convergence problem * Feature/ddp (#49) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * support torch ddp * fix loss accumulation * add log for ddp * change seed * modify timing hook Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * Feature/pipeline (#40) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * optimize communication of pipeline parallel * fix grad clip for pipeline Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51) * Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset * update api for better usability (#58) update api for better usability Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
		
			
				
	
	
		
			2348 lines
		
	
	
		
			99 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			2348 lines
		
	
	
		
			99 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| '''
 | |
| Copyright 2019 The Microsoft DeepSpeed Team
 | |
| '''
 | |
| 
 | |
| import math
 | |
| 
 | |
| import torch
 | |
| import torch.distributed as dist
 | |
| 
 | |
| try:
 | |
|     from deepspeed.git_version_info import version
 | |
|     from deepspeed.moe.utils import is_moe_param
 | |
|     from deepspeed.ops.adam import DeepSpeedCPUAdam
 | |
|     from deepspeed.ops.op_builder import UtilsBuilder
 | |
|     from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 | |
| except ImportError:
 | |
|     pass
 | |
| from packaging import version as pkg_version
 | |
| from torch._six import inf
 | |
| from torch.distributed.distributed_c10d import _get_global_rank
 | |
| from torch.optim import Optimizer
 | |
| 
 | |
| from colossalai.core import global_context as gpc
 | |
| from colossalai.utils import report_memory_usage
 | |
| from colossalai.utils.common import is_model_parallel_parameter
 | |
| from .loss_scaler import LossScaler, DynamicLossScaler
 | |
| from colossalai.context import ParallelMode
 | |
| 
 | |
| # Toggle this to true to enable correctness test
 | |
| # with gradient partitioning and without
 | |
| pg_correctness_test = False
 | |
| 
 | |
| 
 | |
| def input(msg):
 | |
|     return
 | |
| 
 | |
| 
 | |
| def split_half_float_double(tensors):
 | |
|     dtypes = [
 | |
|         "torch.cuda.HalfTensor",
 | |
|         "torch.cuda.FloatTensor",
 | |
|         "torch.cuda.DoubleTensor"
 | |
|     ]
 | |
|     buckets = []
 | |
|     for i, dtype in enumerate(dtypes):
 | |
|         bucket = [t for t in tensors if t.type() == dtype]
 | |
|         if bucket:
 | |
|             buckets.append(bucket)
 | |
|     return buckets
 | |
| 
 | |
| 
 | |
| def isclose(a, b, rtol=1e-09, atol=0.0):
 | |
|     return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
 | |
| 
 | |
| 
 | |
| def lcm(x, y):
 | |
|     from fractions import gcd  # or can import gcd from `math` in Python 3
 | |
|     return x * y // gcd(x, y)
 | |
| 
 | |
| 
 | |
| def get_alignment_padding(tensor_list, alignment):
 | |
|     num_elements = sum([tensor.numel() for tensor in tensor_list])
 | |
|     remainder = num_elements % alignment
 | |
|     return (alignment - remainder) if remainder else remainder
 | |
| 
 | |
| 
 | |
| def move_to_cpu(tensor_list):
 | |
|     for tensor in tensor_list:
 | |
|         tensor.data = tensor.data.cpu()
 | |
| 
 | |
| 
 | |
| def print_rank_msg(msg):
 | |
|     print(f"rank {dist.get_rank()} - {msg}")
 | |
| 
 | |
| 
 | |
| class ZeroRedundancyOptimizer_Level_2(Optimizer):
 | |
|     """
 | |
|     ZeroRedundancyOptimizer_Level_2 designed to reduce the memory footprint
 | |
|     required for training large deep learning models.
 | |
| 
 | |
|     For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
 | |
|     https://arxiv.org/abs/1910.02054
 | |
| 
 | |
|     """
 | |
| 
 | |
|     def __init__(self,
 | |
|                  init_optimizer,
 | |
|                  dp_parallel_mode=ParallelMode.DATA,
 | |
|                  static_loss_scale=1.0,
 | |
|                  dynamic_loss_scale=False,
 | |
|                  dynamic_loss_args=None,
 | |
|                  verbose=False,
 | |
|                  contiguous_gradients=True,
 | |
|                  reduce_bucket_size=500000000,
 | |
|                  allgather_bucket_size=5000000000,
 | |
|                  reduce_scatter=True,
 | |
|                  overlap_comm=False,
 | |
|                  cpu_offload=False,
 | |
|                  clip_grad=0.0,
 | |
|                  allreduce_always_fp32=False,
 | |
|                  postscale_gradients=True,
 | |
|                  gradient_predivide_factor=1.0,
 | |
|                  gradient_accumulation_steps=1,
 | |
|                  ignore_unused_parameters=True,
 | |
|                  round_robin_gradients=False,
 | |
|                  fp16_master_weights_and_gradients=False):
 | |
|         # mpu = None is removed from the parameter list
 | |
|         # tensor parallel will be automatically detected later
 | |
| 
 | |
|         # LSG: default arguments for compatibility
 | |
|         has_moe_layers = False
 | |
|         partition_grads = True
 | |
|         expert_parallel_group = None
 | |
|         expert_data_parallel_group = None
 | |
|         self.timers = None
 | |
|         self.defaults = init_optimizer.defaults
 | |
| 
 | |
|         dp_process_group = gpc.get_group(dp_parallel_mode)
 | |
|         if gpc.get_world_size(dp_parallel_mode) == 1:
 | |
|             partition_grads = False  # for compatibility with dp size = 1
 | |
| 
 | |
|         self.verbose = verbose
 | |
| 
 | |
|         if dist.get_rank() == 0 and self.verbose:
 | |
|             print(f"Reduce bucket size {reduce_bucket_size}")
 | |
|             print(f"Allgather bucket size {allgather_bucket_size}")
 | |
|             print(f"CPU Offload: {cpu_offload}")
 | |
|             print(
 | |
|                 f'Round robin gradient partitioning: {round_robin_gradients}')
 | |
|         # The fused optimizer does all the work. We need this layer for two reason:
 | |
|         # 1. maintain same user API from apex.fp16_utils
 | |
|         # 2. keep common stuff here in case we need to add ne552w fused optimizer later
 | |
| 
 | |
|         # differences from apex.fp16_utils:
 | |
|         # - assume all model params in fp16
 | |
|         # - assume all params requires grad
 | |
|         # - flat by groups, not keeping state. TODO: remove state explicitly?
 | |
|         # - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
 | |
|         if not torch.cuda.is_available:
 | |
|             raise SystemError("Cannot use fp16 without CUDA.")
 | |
|         self.optimizer = init_optimizer
 | |
| 
 | |
|         # Load pre-built or JIT compile (un)flatten ops
 | |
|         util_ops = UtilsBuilder().load()
 | |
|         self.flatten = util_ops.flatten
 | |
|         self.unflatten = util_ops.unflatten
 | |
| 
 | |
|         # ZeRO stage 1 (False) or 2 (True)
 | |
|         self.partition_gradients = partition_grads
 | |
| 
 | |
|         self.reduce_scatter = reduce_scatter
 | |
| 
 | |
|         self.overlap_comm = overlap_comm
 | |
| 
 | |
|         self.cpu_offload = cpu_offload
 | |
| 
 | |
|         self.deepspeed_adam_offload = cpu_offload
 | |
| 
 | |
|         self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
 | |
| 
 | |
|         self.dp_process_group = dp_process_group
 | |
| 
 | |
|         # expert parallel group
 | |
|         self.ep_process_group = expert_parallel_group
 | |
| 
 | |
|         # data parallel group for experts
 | |
|         self.expert_dp_process_group = expert_data_parallel_group
 | |
| 
 | |
|         # data parallel size for non-experts
 | |
|         dp_size = dist.get_world_size(group=self.dp_process_group)
 | |
| 
 | |
|         # For MoE models this maybe different for different param group
 | |
|         # It will be modified during MoE setup later in the init
 | |
|         self.real_dp_process_group = [
 | |
|             dp_process_group for i in range(len(self.optimizer.param_groups))
 | |
|         ]
 | |
|         self.partition_count = [dp_size for i in range(
 | |
|             len(self.optimizer.param_groups))]
 | |
| 
 | |
|         self.is_gradient_accumulation_boundary = True
 | |
| 
 | |
|         # CPU-Offload requires contiguous gradients
 | |
|         self.contiguous_gradients = contiguous_gradients or cpu_offload
 | |
| 
 | |
|         self.has_moe_layers = has_moe_layers
 | |
| 
 | |
|         if self.has_moe_layers:
 | |
|             self._configure_moe_settings()
 | |
| 
 | |
|         if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_world_size(ParallelMode.TENSOR) == 1:
 | |
|             self.model_parallel_group = None
 | |
|             self.model_parallel_rank = 0
 | |
|         else:
 | |
|             self.model_parallel_group = gpc.get_group(ParallelMode.TENSOR)
 | |
|             self.model_parallel_rank = gpc.get_local_rank(ParallelMode.TENSOR)
 | |
| 
 | |
|         self.overflow = False
 | |
|         self.clip_grad = clip_grad
 | |
|         self.allreduce_always_fp32 = allreduce_always_fp32
 | |
|         self.gradient_predivide_factor = gradient_predivide_factor
 | |
|         self.postscale_gradients = postscale_gradients
 | |
|         self.gradient_accumulation_steps = gradient_accumulation_steps
 | |
|         self.micro_step_id = 0
 | |
|         self.ignore_unused_parameters = ignore_unused_parameters
 | |
|         self.round_robin_gradients = round_robin_gradients
 | |
| 
 | |
|         self.extra_large_param_to_reduce = None
 | |
|         self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
 | |
| 
 | |
|         if self.fp16_master_weights_and_gradients:
 | |
|             assert self.cpu_offload and type(self.optimizer) in [
 | |
|                 DeepSpeedCPUAdam], f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32. Currenty only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. Either disable fp16_master_weights_and_gradients or enable ZeRO-2 Offload with DeepSpeedCPUAdam"
 | |
| 
 | |
|         if self.reduce_scatter:
 | |
|             assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
 | |
|             assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
 | |
|             assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
 | |
| 
 | |
|         # param flattened by groups
 | |
|         self.fp16_groups = []
 | |
|         self.fp16_groups_flat = []
 | |
| 
 | |
|         # param partitioned by data parallel degree
 | |
|         # this will contain a list of equal sized tensors
 | |
|         # each of which will be updated by a different process
 | |
|         self.parallel_partitioned_fp16_groups = []
 | |
| 
 | |
|         # a single 32-bit partition of the parallel partitioned parameters
 | |
|         # that this process will update
 | |
|         self.single_partition_of_fp32_groups = []
 | |
| 
 | |
|         # param partition info
 | |
| 
 | |
|         # These are the parameters in each group that will not be updated by this process directly
 | |
|         self.params_not_in_partition = []
 | |
| 
 | |
|         # These are the parameters that will be updated by this process directly
 | |
|         self.params_in_partition = []
 | |
| 
 | |
|         # Offset from the first paramter in the the self.params_in_partition
 | |
|         # the parameter boundaries may not align with partition boundaries
 | |
|         # so we need to keep track of the offset
 | |
|         self.first_offset = []
 | |
| 
 | |
|         # number of elements per partition in each group
 | |
|         self.partition_size = []
 | |
| 
 | |
|         # align nccl all-gather send buffers to 4-bye boundary
 | |
|         # 4-byte alignment/sizeof(fp16) = 2
 | |
|         self.nccl_start_alignment_factor = 2
 | |
| 
 | |
|         assert (
 | |
|             allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
 | |
| 
 | |
|         self.all_reduce_print = False
 | |
|         self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
 | |
| 
 | |
|         self.round_robin_fp16_groups = []
 | |
|         self.round_robin_fp6_indices = []
 | |
| 
 | |
|         # padding on each partition for alignment purposes
 | |
|         self.groups_padding = []
 | |
|         # loop to deal with groups
 | |
|         for i, param_group in enumerate(self.optimizer.param_groups):
 | |
|             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
 | |
| 
 | |
|             # push this group to list before modify
 | |
|             # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
 | |
|             self.fp16_groups.append(param_group['params'])
 | |
| 
 | |
|             # Record padding required to align group to world size
 | |
|             if partition_id == dist.get_world_size(
 | |
|                     group=self.real_dp_process_group[i]) - 1:
 | |
|                 padding = get_alignment_padding(self.fp16_groups[i],
 | |
|                                                 self.partition_count[i])
 | |
|             else:
 | |
|                 padding = 0
 | |
|             self.groups_padding.append(padding)
 | |
| 
 | |
|             # not sure why apex was cloning the weights before flattening
 | |
|             # removing cloning here
 | |
| 
 | |
|             if self.verbose:
 | |
|                 report_memory_usage(f"Before moving param group {i} to CPU")
 | |
|             # move all the parameters to cpu to free up GPU space for creating flat buffer
 | |
|             move_to_cpu(self.fp16_groups[i])
 | |
|             if self.verbose:
 | |
|                 report_memory_usage(f"After moving param group {i} to CPU")
 | |
| 
 | |
|             # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
 | |
|             # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
 | |
|             # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
 | |
|             # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
 | |
|             if self.round_robin_gradients:
 | |
|                 round_robin_tensors, round_robin_indices = self._round_robin_reorder(
 | |
|                     self.fp16_groups[i],
 | |
|                     dist.get_world_size(group=self.real_dp_process_group[i])
 | |
|                 )
 | |
|             else:
 | |
|                 round_robin_tensors = self.fp16_groups[i]
 | |
|                 round_robin_indices = list(range(len(self.fp16_groups[i])))
 | |
| 
 | |
|             self.round_robin_fp16_groups.append(round_robin_tensors)
 | |
|             self.round_robin_fp6_indices.append(round_robin_indices)
 | |
| 
 | |
|             # create flat buffer in CPU and move to GPU
 | |
|             self.fp16_groups_flat.append(
 | |
|                 self.flatten_dense_tensors_aligned(
 | |
|                     self.round_robin_fp16_groups[i],
 | |
|                     self.nccl_start_alignment_factor *
 | |
|                     dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
 | |
|                     torch.cuda.current_device()))
 | |
| 
 | |
|             if self.verbose:
 | |
|                 report_memory_usage(
 | |
|                     f"After flattening and moving param group {i} to GPU")
 | |
| 
 | |
|                 if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
 | |
|                     report_memory_usage(
 | |
|                         f"After Flattening and after emptying param group {i} cache")
 | |
| 
 | |
|             # set model fp16 weight to slices of flattened buffer
 | |
|             self._update_model_fp16_weights(i)
 | |
| 
 | |
|             # divide the flat weights into near equal partition equal to the data parallel degree
 | |
|             # each process will compute on a different part of the partition
 | |
|             data_parallel_partitions = self.get_data_parallel_partitions(
 | |
|                 self.fp16_groups_flat[i],
 | |
|                 i)
 | |
|             self.parallel_partitioned_fp16_groups.append(
 | |
|                 data_parallel_partitions)
 | |
| 
 | |
|             # verify that data partition start locations are 4-byte aligned
 | |
|             for partitioned_data in data_parallel_partitions:
 | |
|                 assert (partitioned_data.data_ptr() %
 | |
|                         (2 * self.nccl_start_alignment_factor) == 0)
 | |
| 
 | |
|             # a partition of the fp32 master weights that will be updated by this process
 | |
|             if not fp16_master_weights_and_gradients:
 | |
|                 self.single_partition_of_fp32_groups.append(
 | |
|                     self.parallel_partitioned_fp16_groups[i][partition_id].to(
 | |
|                         self.device).clone().float().detach())
 | |
|             else:
 | |
|                 self.single_partition_of_fp32_groups.append(
 | |
|                     self.parallel_partitioned_fp16_groups[i][partition_id].to(
 | |
|                         self.device).clone().half().detach())
 | |
| 
 | |
|             # modify optimizer of have flat master weight
 | |
|             self.single_partition_of_fp32_groups[
 | |
|                 i].requires_grad = True  # keep this in case internal optimizer uses it
 | |
|             param_group['params'] = [self.single_partition_of_fp32_groups[i]]
 | |
| 
 | |
|             partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
 | |
|                 group=self.real_dp_process_group[i])
 | |
|             params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
 | |
|                 self.round_robin_fp16_groups[i],
 | |
|                 partition_size,
 | |
|                 partition_id)
 | |
| 
 | |
|             self.partition_size.append(partition_size)
 | |
|             self.params_in_partition.append(params_in_partition)
 | |
|             self.params_not_in_partition.append(params_not_in_partition)
 | |
|             self.first_offset.append(first_offset)
 | |
| 
 | |
|         for rank in range(dist.get_world_size()):
 | |
|             if dist.get_rank() == rank and self.verbose:
 | |
|                 print(
 | |
|                     f"Rank: {rank} partition count {self.partition_count} and sizes{[(p.numel(), self.is_moe_param_group[i] if hasattr(self, 'is_moe_param_group') else False) for i, p in enumerate(self.single_partition_of_fp32_groups)]} "
 | |
|                 )
 | |
|                 dist.barrier()
 | |
|         # exit(0)
 | |
|         self.reduce_bucket_size = int(reduce_bucket_size)
 | |
|         self.allgather_bucket_size = int(allgather_bucket_size)
 | |
| 
 | |
|         self.reduction_event = torch.cuda.Event(
 | |
|             enable_timing=False, blocking=False)
 | |
|         self.reduction_stream = torch.cuda.Stream()
 | |
|         self.cpu_computation_stream = torch.cuda.Stream()
 | |
|         self.copy_grad_stream = torch.cuda.Stream()
 | |
|         self.callback_queued = False
 | |
| 
 | |
|         self.param_dict = {}
 | |
| 
 | |
|         # map between param_id and bool to specify if a param is in this partition
 | |
|         self.is_param_in_current_partition = {}
 | |
| 
 | |
|         self.grads_in_ipg_bucket = []
 | |
|         self.params_in_ipg_bucket = []
 | |
|         self.elements_in_ipg_bucket = 0
 | |
|         self.params_already_reduced = []
 | |
|         self._release_ipg_buffers()
 | |
|         self.previous_reduced_grads = None
 | |
|         self.ipg_bucket_has_moe_params = False
 | |
| 
 | |
|         # simplified param id
 | |
|         self.param_id = {}
 | |
| 
 | |
|         largest_param_numel = 0
 | |
|         count = 0
 | |
|         for i, params_group in enumerate(self.fp16_groups):
 | |
|             for param in params_group:
 | |
|                 unique_id = id(param)
 | |
|                 self.param_id[unique_id] = count
 | |
|                 self.param_dict[count] = param
 | |
|                 self.params_already_reduced.append(False)
 | |
|                 if param.numel() > largest_param_numel:
 | |
|                     largest_param_numel = param.numel()
 | |
|                 count = count + 1
 | |
| 
 | |
|         for param_group in self.params_in_partition:
 | |
|             for param in param_group:
 | |
|                 self.is_param_in_current_partition[self.get_param_id(
 | |
|                     param)] = True
 | |
| 
 | |
|         for param_group in self.params_not_in_partition:
 | |
|             for param in param_group:
 | |
|                 self.is_param_in_current_partition[self.get_param_id(
 | |
|                     param)] = False
 | |
| 
 | |
|         if self.cpu_offload:
 | |
|             self.accumulated_grads_in_cpu = {}
 | |
|             self.norm_for_param_grads = {}
 | |
|             self.local_overflow = False
 | |
|             self.grad_position = {}
 | |
|             self.temp_grad_buffer_for_cpu_offload = torch.zeros(
 | |
|                 largest_param_numel,
 | |
|                 device=self.device,
 | |
|                 dtype=self.dtype).pin_memory()
 | |
|             self.temp_grad_buffer_for_gpu_offload = torch.zeros(
 | |
|                 largest_param_numel,
 | |
|                 device=torch.cuda.current_device(),
 | |
|                 dtype=self.dtype)
 | |
| 
 | |
|             for i, params_group in enumerate(self.fp16_groups):
 | |
|                 self.get_grad_position(i,
 | |
|                                        self.params_in_partition[i],
 | |
|                                        self.first_offset[i],
 | |
|                                        self.partition_size[i])
 | |
| 
 | |
|         # mapping from parameter to partition that it belongs to
 | |
|         self.param_to_partition_ids = {}
 | |
| 
 | |
|         # stores if a partition has been reduced in this step
 | |
|         self.is_partition_reduced = {}
 | |
| 
 | |
|         # number of grads in partition that still need to be computed
 | |
|         self.remaining_grads_in_partition = {}
 | |
| 
 | |
|         # total number of grads in partition
 | |
|         self.total_grads_in_partition = {}
 | |
| 
 | |
|         # stores if a grad in a partition has been computed or not
 | |
|         self.is_grad_computed = {}
 | |
| 
 | |
|         # stores the offset at which a parameter gradient needs to be inserted in a partition
 | |
|         self.grad_partition_insertion_offset = {}
 | |
| 
 | |
|         # the offset in the gradient at which it must be inserted at the beginning of the partition
 | |
|         self.grad_start_offset = {}
 | |
| 
 | |
|         # will store the averaged gradients required by this partition
 | |
|         self.averaged_gradients = {}
 | |
| 
 | |
|         # store index of first parameter in each partition
 | |
|         self.first_param_index_in_partition = {}
 | |
| 
 | |
|         # initializes all data structures for implementing gradient partitioning
 | |
|         self.initialize_gradient_partitioning_data_structures()
 | |
| 
 | |
|         # resets the data structure value for the next backward propagation
 | |
|         self.reset_partition_gradient_structures()
 | |
| 
 | |
|         # creates backward hooks for gradient partitioning
 | |
|         if self.partition_gradients or self.overlap_comm:
 | |
|             self.create_reduce_and_remove_grad_hooks()
 | |
| 
 | |
|         # we may have a way of fusing dynamic scale. Do not support for now
 | |
|         if self.dtype == torch.float or not dynamic_loss_scale:
 | |
|             loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
 | |
| 
 | |
|             self.dynamic_loss_scale = False
 | |
|             self.loss_scaler = LossScaler(scale=loss_scale_value)
 | |
|             cur_iter = 0
 | |
|         else:
 | |
|             if dynamic_loss_args is None:
 | |
|                 self.loss_scaler = DynamicLossScaler()
 | |
|             else:
 | |
|                 self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
 | |
| 
 | |
|             self.dynamic_loss_scale = True
 | |
| 
 | |
|         if self.verbose:
 | |
|             report_memory_usage("Before initializing optimizer states")
 | |
|         self.initialize_optimizer_states()
 | |
|         if self.verbose:
 | |
|             report_memory_usage("After initializing optimizer states")
 | |
| 
 | |
|             if dist.get_rank() == 0:
 | |
|                 print(f"optimizer state initialized")
 | |
| 
 | |
|             if dist.get_rank(group=self.dp_process_group) == 0:
 | |
|                 report_memory_usage(f"After initializing ZeRO optimizer")
 | |
| 
 | |
|     def _configure_moe_settings(self):
 | |
|         assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
 | |
|         assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
 | |
| 
 | |
|         def is_moe_group(group):
 | |
|             return 'moe' in group and group['moe']
 | |
| 
 | |
|         assert any([is_moe_group(group) for group in
 | |
|                     self.optimizer.param_groups]), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
 | |
|         self.is_moe_param_group = []
 | |
|         for i, group in enumerate(self.optimizer.param_groups):
 | |
|             if is_moe_group(group):
 | |
|                 assert all(
 | |
|                     [is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params"
 | |
|                 self.real_dp_process_group[i] = self.expert_dp_process_group
 | |
|                 self.partition_count[i] = dist.get_world_size(
 | |
|                     group=self.expert_dp_process_group)
 | |
|                 self.is_moe_param_group.append(True)
 | |
|             else:
 | |
|                 self.is_moe_param_group.append(False)
 | |
| 
 | |
|         assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
 | |
|         assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
 | |
| 
 | |
|     def _update_model_fp16_weights(self, group_index):
 | |
|         updated_params = self.unflatten(self.fp16_groups_flat[group_index],
 | |
|                                         self.round_robin_fp16_groups[group_index])
 | |
|         for p, q in zip(self.round_robin_fp16_groups[group_index], updated_params):
 | |
|             p.data = q.data
 | |
| 
 | |
|         # set model fp16 weight to slices of reordered flattened buffer
 | |
|         for param_index, param in enumerate(self.fp16_groups[group_index]):
 | |
|             new_index = self.round_robin_fp6_indices[group_index][param_index]
 | |
|             param.data = self.round_robin_fp16_groups[group_index][new_index].data
 | |
| 
 | |
|     def _round_robin_reorder(self, tensor_list, num_partitions):
 | |
| 
 | |
|         # disable round robin if need to debug something
 | |
|         # return tensor_list, list(range(len(tensor_list)))
 | |
| 
 | |
|         partition_tensors = {}
 | |
| 
 | |
|         for i, tensor in enumerate(tensor_list):
 | |
|             j = i % num_partitions
 | |
|             if not j in partition_tensors:
 | |
|                 partition_tensors[j] = []
 | |
|             partition_tensors[j].append((i, tensor))
 | |
| 
 | |
|         reordered_tensors = []
 | |
|         reordered_indices = {}
 | |
| 
 | |
|         for partition_index in partition_tensors.keys():
 | |
|             for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
 | |
|                 reordered_indices[original_index] = len(reordered_tensors)
 | |
|                 reordered_tensors.append(tensor)
 | |
| 
 | |
|         return reordered_tensors, reordered_indices
 | |
| 
 | |
|     def _release_ipg_buffers(self):
 | |
|         if self.contiguous_gradients:
 | |
|             self.ipg_buffer = None
 | |
|             self.grads_in_partition = None
 | |
|             self.grads_in_partition_offset = 0
 | |
| 
 | |
|     def initialize_optimizer_states(self):
 | |
| 
 | |
|         for i, group in enumerate(self.fp16_groups):
 | |
|             single_grad_partition = torch.zeros(
 | |
|                 int(self.partition_size[i]),
 | |
|                 dtype=self.single_partition_of_fp32_groups[i].dtype,
 | |
|                 device=self.device)
 | |
|             self.single_partition_of_fp32_groups[
 | |
|                 i].grad = single_grad_partition.pin_memory(
 | |
|             ) if self.cpu_offload else single_grad_partition
 | |
| 
 | |
|         self.optimizer.step()
 | |
| 
 | |
|         if not self.cpu_offload:
 | |
|             for group in self.single_partition_of_fp32_groups:
 | |
|                 group.grad = None  # class init
 | |
| 
 | |
|         return
 | |
| 
 | |
|     #########################################################################
 | |
|     #################### ZeRO Stage 1 - reduce gradients ####################
 | |
|     #########################################################################
 | |
| 
 | |
|     def reduce_gradients(self, pipeline_parallel=False):
 | |
|         world_size = dist.get_world_size(self.dp_process_group)
 | |
|         my_rank = dist.get_rank(self.dp_process_group)
 | |
| 
 | |
|         # with PP we must create ipg buffer, since backward is handled outside zero
 | |
|         if pipeline_parallel and self.contiguous_gradients:
 | |
|             self.ipg_buffer = []
 | |
|             buf_0 = torch.empty(int(self.reduce_bucket_size),
 | |
|                                 dtype=self.dtype,
 | |
|                                 device=torch.cuda.current_device())
 | |
|             self.ipg_buffer.append(buf_0)
 | |
|             self.ipg_index = 0
 | |
| 
 | |
|         if not self.overlap_comm:
 | |
|             for i, group in enumerate(self.fp16_groups):
 | |
|                 for param in group:
 | |
|                     if param.grad is not None:
 | |
|                         self.reduce_ready_partitions_and_remove_grads(param, i)
 | |
| 
 | |
|         # reduce any pending grads in either hook/non-hook case
 | |
|         self.overlapping_partition_gradients_reduce_epilogue()
 | |
| 
 | |
|     #########################################################################
 | |
|     #########################ZeRO Partition Gradients########################
 | |
|     #########################################################################
 | |
| 
 | |
|     def get_first_param_index(self, group_id, param_group, partition_id):
 | |
|         for index, param in enumerate(param_group):
 | |
|             param_id = self.get_param_id(param)
 | |
|             if partition_id in self.param_to_partition_ids[group_id][param_id]:
 | |
|                 return index
 | |
|         return None
 | |
| 
 | |
|     def initialize_gradient_partitioning_data_structures(self):
 | |
| 
 | |
|         for i, param_group in enumerate(self.round_robin_fp16_groups):
 | |
| 
 | |
|             total_partitions = dist.get_world_size(
 | |
|                 group=self.real_dp_process_group[i])
 | |
| 
 | |
|             self.param_to_partition_ids[i] = {}
 | |
|             self.is_partition_reduced[i] = {}
 | |
|             self.total_grads_in_partition[i] = {}
 | |
|             self.remaining_grads_in_partition[i] = {}
 | |
|             self.is_grad_computed[i] = {}
 | |
|             self.grad_partition_insertion_offset[i] = {}
 | |
|             self.grad_start_offset[i] = {}
 | |
|             self.first_param_index_in_partition[i] = {}
 | |
| 
 | |
|             for partition_id in range(total_partitions):
 | |
|                 self.is_grad_computed[i][partition_id] = {}
 | |
|                 self.grad_partition_insertion_offset[i][partition_id] = {}
 | |
|                 self.grad_start_offset[i][partition_id] = {}
 | |
|                 self.total_grads_in_partition[i][partition_id] = 0
 | |
|                 self.initialize_gradient_partition(
 | |
|                     i, param_group, partition_id)
 | |
|                 self.is_partition_reduced[i][partition_id] = False
 | |
|                 self.first_param_index_in_partition[i][
 | |
|                     partition_id] = self.get_first_param_index(
 | |
|                     i,
 | |
|                     param_group,
 | |
|                     partition_id)
 | |
| 
 | |
|     def independent_gradient_partition_epilogue(self):
 | |
|         if self.verbose:
 | |
|             self.report_ipg_memory_usage(
 | |
|                 f"In ipg_epilogue before reduce_ipg_grads", 0)
 | |
|         self.reduce_ipg_grads()
 | |
|         if self.verbose:
 | |
|             self.report_ipg_memory_usage(
 | |
|                 f"In ipg_epilogue after reduce_ipg_grads", 0)
 | |
| 
 | |
|         # if dist.get_rank() == 0:
 | |
|         #    print()("Params already reduced %s", self.params_already_reduced)
 | |
|         for i in range(len(self.params_already_reduced)):
 | |
|             self.params_already_reduced[i] = False
 | |
| 
 | |
|         if self.overlap_comm:
 | |
|             torch.cuda.synchronize()
 | |
|             # It is safe to clear previously reduced grads of other partitions
 | |
|             self._clear_previous_reduced_grads()
 | |
| 
 | |
|         if self.cpu_offload is False:
 | |
|             for i, _ in enumerate(self.fp16_groups):
 | |
| 
 | |
|                 if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
 | |
|                     self.averaged_gradients[i] = self.get_flat_partition(
 | |
|                         self.params_in_partition[i],
 | |
|                         self.first_offset[i],
 | |
|                         self.partition_size[i],
 | |
|                         dtype=self.dtype,
 | |
|                         device=torch.cuda.current_device(),
 | |
|                         return_tensor_list=True)
 | |
|                 else:
 | |
|                     avg_new = self.get_flat_partition(self.params_in_partition[i],
 | |
|                                                       self.first_offset[i],
 | |
|                                                       self.partition_size[i],
 | |
|                                                       dtype=self.dtype,
 | |
|                                                       device=torch.cuda.current_device(),
 | |
|                                                       return_tensor_list=True)
 | |
| 
 | |
|                     for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
 | |
|                         accumulated_grad.add_(new_avg_grad)
 | |
| 
 | |
|         self._release_ipg_buffers()
 | |
| 
 | |
|         # No need to keep the gradients anymore.
 | |
|         # All gradients required by the step
 | |
|         # are in self.averaged_gradients
 | |
|         self.zero_grad()
 | |
| 
 | |
|         if self.verbose:
 | |
|             report_memory_usage(f"End ipg_epilogue")
 | |
| 
 | |
|     # resets all partition to no reduced
 | |
|     # sets remaining grads to the total number of grads in each partition
 | |
|     # set is grad computed to false for all grads in partition
 | |
|     def reset_partition_gradient_structures(self):
 | |
|         for i, _ in enumerate(self.fp16_groups):
 | |
|             total_partitions = dist.get_world_size(
 | |
|                 group=self.real_dp_process_group[i])
 | |
|             for partition_id in range(total_partitions):
 | |
|                 self.is_partition_reduced[i][partition_id] = False
 | |
|                 self.remaining_grads_in_partition[i][
 | |
|                     partition_id] = self.total_grads_in_partition[i][partition_id]
 | |
| 
 | |
|                 for param_id in self.is_grad_computed[i][partition_id]:
 | |
|                     self.is_grad_computed[i][partition_id][param_id] = False
 | |
| 
 | |
|     def initialize_gradient_partition(self, i, param_group, partition_id):
 | |
|         def set_key_value_list(dictionary, key, value):
 | |
|             if key in dictionary:
 | |
|                 dictionary[key].append(value)
 | |
|             else:
 | |
|                 dictionary[key] = [value]
 | |
| 
 | |
|         def increment_value(dictionary, key):
 | |
|             if key in dictionary:
 | |
|                 dictionary[key] += 1
 | |
|             else:
 | |
|                 dictionary[key] = 1
 | |
| 
 | |
|         partition_size = self.partition_size[i]
 | |
| 
 | |
|         start_index = partition_size * partition_id
 | |
|         end_index = partition_size * (partition_id + 1)
 | |
| 
 | |
|         current_index = 0
 | |
|         first_offset = 0
 | |
| 
 | |
|         for param in param_group:
 | |
| 
 | |
|             param_size = param.numel()
 | |
|             param_id = self.get_param_id(param)
 | |
| 
 | |
|             if (current_index >= start_index and current_index < end_index):
 | |
|                 set_key_value_list(self.param_to_partition_ids[i],
 | |
|                                    param_id,
 | |
|                                    partition_id)
 | |
|                 increment_value(self.total_grads_in_partition[i], partition_id)
 | |
| 
 | |
|                 self.is_grad_computed[i][partition_id][param_id] = False
 | |
| 
 | |
|                 self.grad_partition_insertion_offset[i][partition_id][
 | |
|                     param_id] = current_index - start_index
 | |
|                 self.grad_start_offset[i][partition_id][param_id] = 0
 | |
| 
 | |
|             elif start_index > current_index and start_index < (current_index +
 | |
|                                                                 param_size):
 | |
|                 assert (
 | |
|                     first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
 | |
|                 first_offset = start_index - current_index
 | |
| 
 | |
|                 set_key_value_list(self.param_to_partition_ids[i],
 | |
|                                    param_id,
 | |
|                                    partition_id)
 | |
|                 increment_value(self.total_grads_in_partition[i], partition_id)
 | |
| 
 | |
|                 self.is_grad_computed[i][partition_id][param_id] = False
 | |
| 
 | |
|                 self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
 | |
|                 self.grad_start_offset[i][partition_id][param_id] = first_offset
 | |
| 
 | |
|             current_index = current_index + param_size
 | |
| 
 | |
|     def overlapping_partition_gradients_reduce_epilogue(self):
 | |
|         self.independent_gradient_partition_epilogue()
 | |
| 
 | |
|     def create_reduce_and_remove_grad_hooks(self):
 | |
|         self.grad_accs = []
 | |
|         for i, param_group in enumerate(self.fp16_groups):
 | |
|             for param in param_group:
 | |
|                 if param.requires_grad:
 | |
|                     def wrapper(param, i):
 | |
|                         param_tmp = param.expand_as(param)
 | |
|                         grad_acc = param_tmp.grad_fn.next_functions[0][0]
 | |
| 
 | |
|                         def reduce_partition_and_remove_grads(*notneeded):
 | |
|                             self.reduce_ready_partitions_and_remove_grads(
 | |
|                                 param, i)
 | |
| 
 | |
|                         grad_acc.register_hook(
 | |
|                             reduce_partition_and_remove_grads)
 | |
|                         self.grad_accs.append(grad_acc)
 | |
| 
 | |
|                     wrapper(param, i)
 | |
| 
 | |
|     def get_param_id(self, param):
 | |
|         unique_id = id(param)
 | |
|         return self.param_id[unique_id]
 | |
| 
 | |
|     def report_ipg_memory_usage(self, tag, param_elems):
 | |
|         elem_count = self.elements_in_ipg_bucket + param_elems
 | |
|         percent_of_bucket_size = (
 | |
|             100.0 * elem_count) // self.reduce_bucket_size
 | |
|         if self.verbose:
 | |
|             report_memory_usage(
 | |
|                 f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
 | |
|             )
 | |
| 
 | |
|     # create a flat tensor aligned at the alignment boundary
 | |
|     def flatten_dense_tensors_aligned(self, tensor_list, alignment):
 | |
|         num_elements = 0
 | |
|         for tensor in tensor_list:
 | |
|             num_elements = num_elements + tensor.numel()
 | |
| 
 | |
|         remaining = num_elements % alignment
 | |
| 
 | |
|         if remaining:
 | |
|             elements_to_add = alignment - remaining
 | |
|             pad_tensor = torch.zeros(elements_to_add,
 | |
|                                      device=tensor_list[0].device,
 | |
|                                      dtype=tensor_list[0].dtype)
 | |
|             padded_tensor_list = tensor_list + [pad_tensor]
 | |
| 
 | |
|             num_elements = num_elements + elements_to_add
 | |
|         else:
 | |
|             padded_tensor_list = tensor_list
 | |
| 
 | |
|         return self.flatten(padded_tensor_list)
 | |
| 
 | |
|     ############### Independent Partition Gradient ########################
 | |
|     def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
 | |
|         if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
 | |
|             self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
 | |
|                                          param.numel())
 | |
|             self.reduce_ipg_grads()
 | |
|             if self.contiguous_gradients and self.overlap_comm:
 | |
|                 # Swap ipg_index between 0 and 1
 | |
|                 self.ipg_index = 1 - self.ipg_index
 | |
| 
 | |
|             self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
 | |
|                                          param.numel())
 | |
| 
 | |
|         param_id = self.get_param_id(param)
 | |
|         assert self.params_already_reduced[param_id] == False, \
 | |
|             f"The parameter {param_id} has already been reduced. \
 | |
|             Gradient computed twice for this partition. \
 | |
|             Multiple gradient reduction is currently not supported"
 | |
| 
 | |
|         if param.numel() > self.reduce_bucket_size:
 | |
|             self.extra_large_param_to_reduce = param
 | |
| 
 | |
|         elif self.contiguous_gradients:
 | |
|             # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
 | |
|             new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
 | |
|                 0,
 | |
|                 self.elements_in_ipg_bucket,
 | |
|                 param.numel())
 | |
|             new_grad_tensor.copy_(param.grad.view(-1))
 | |
|             param.grad.data = new_grad_tensor.data.view_as(param.grad)
 | |
| 
 | |
|         self.elements_in_ipg_bucket += param.numel()
 | |
| 
 | |
|         assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
 | |
| 
 | |
|         self.grads_in_ipg_bucket.append(param.grad)
 | |
|         self.params_in_ipg_bucket.append((i, param, param_id))
 | |
| 
 | |
|         # make sure the average tensor function knows how to average the gradients
 | |
|         if is_moe_param(param):
 | |
|             self.ipg_bucket_has_moe_params = True
 | |
| 
 | |
|         self.report_ipg_memory_usage("End ipg_remove_grads", 0)
 | |
| 
 | |
|     def print_rank_0(self, message):
 | |
|         if dist.get_rank() == 0 and self.verbose:
 | |
|             print(message)
 | |
| 
 | |
|     def gradient_reduction_w_predivide(self, tensor):
 | |
| 
 | |
|         dp_world_size = dist.get_world_size(group=self.dp_process_group)
 | |
| 
 | |
|         tensor_to_allreduce = tensor
 | |
| 
 | |
|         if self.allreduce_always_fp32:
 | |
|             tensor_to_allreduce = tensor.float()
 | |
| 
 | |
|         if self.postscale_gradients:
 | |
|             if self.gradient_predivide_factor != 1.0:
 | |
|                 tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
 | |
| 
 | |
|             dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
 | |
| 
 | |
|             if self.gradient_predivide_factor != dp_world_size:
 | |
|                 tensor_to_allreduce.mul_(
 | |
|                     self.gradient_predivide_factor / dp_world_size)
 | |
|         else:
 | |
|             tensor_to_allreduce.div_(dp_world_size)
 | |
|             dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
 | |
| 
 | |
|         if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
 | |
|             tensor.copy_(tensor_to_allreduce)
 | |
| 
 | |
|         return tensor
 | |
| 
 | |
|     def average_tensor(self, tensor):
 | |
|         if self.overlap_comm:
 | |
|             torch.cuda.synchronize()
 | |
|             stream = self.reduction_stream
 | |
|         else:
 | |
|             stream = torch.cuda.current_stream()
 | |
| 
 | |
|         with torch.cuda.stream(stream):
 | |
|             if not self.reduce_scatter:
 | |
|                 self.gradient_reduction_w_predivide(tensor)
 | |
|                 return
 | |
| 
 | |
|             # Accumulate destination ranks and bucket offsets for each gradient slice.
 | |
|             # Note: potential future optimization, record access pattern of parameters
 | |
|             # in backward pass and partition gradients w.r.t. access pattern so that our
 | |
|             # bucket is guaranteed to be contiguous w.r.t. ranks
 | |
|             rank_and_offsets = []
 | |
|             real_dp_process_group = []
 | |
|             curr_size = 0
 | |
|             prev_id = -1
 | |
| 
 | |
|             process_group = self.dp_process_group
 | |
|             # count = 0
 | |
|             for i, param, param_id in self.params_in_ipg_bucket:
 | |
| 
 | |
|                 process_group = self.dp_process_group
 | |
|                 # Averages gradients at parameter level if ipg has a moe param
 | |
|                 # Otherwise averaging is done at the entire buffer level at the end of the loop
 | |
|                 if self.ipg_bucket_has_moe_params:
 | |
|                     process_group = self.expert_dp_process_group if is_moe_param(
 | |
|                         param) else self.dp_process_group
 | |
|                     param.grad.data.div_(
 | |
|                         dist.get_world_size(group=process_group))
 | |
| 
 | |
|                 partition_ids = self.param_to_partition_ids[i][param_id]
 | |
|                 partition_size = self.partition_size[i]
 | |
|                 # Get all partition ids + their offsets
 | |
|                 partition_ids_w_offsets = []
 | |
|                 for partition_id in partition_ids:
 | |
|                     offset = self.grad_start_offset[i][partition_id][param_id]
 | |
|                     partition_ids_w_offsets.append((partition_id, offset))
 | |
|                 partition_ids_w_offsets.sort(key=lambda t: t[1])
 | |
| 
 | |
|                 # Calculate rank and offsets for grad slices
 | |
|                 for idx in range(len(partition_ids_w_offsets)):
 | |
|                     partition_id, offset = partition_ids_w_offsets[idx]
 | |
| 
 | |
|                     # if dist.get_rank() == 0 and count < 100:
 | |
|                     #     print(f"Rank {dist.get_rank()} rank offet id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
 | |
|                     # count += 1
 | |
| 
 | |
|                     # Calculate numel for grad slice depending on partition location
 | |
|                     if idx == len(partition_ids_w_offsets) - 1:
 | |
|                         # Last partition_id uses its own offset
 | |
|                         numel = param.numel() - offset
 | |
|                     else:
 | |
|                         # Set numel to next partition's offset
 | |
|                         numel = partition_ids_w_offsets[idx + 1][1] - offset
 | |
| 
 | |
|                     # Merge bucket ranges if they belong to the same rank
 | |
|                     if partition_id == prev_id:
 | |
|                         prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
 | |
|                         rank_and_offsets[-1] = (prev_pid,
 | |
|                                                 prev_size, prev_numel + numel)
 | |
|                     else:
 | |
|                         rank_and_offsets.append(
 | |
|                             (partition_id, curr_size, numel))
 | |
|                         real_dp_process_group.append(process_group)
 | |
|                     curr_size += numel
 | |
|                     prev_id = partition_id
 | |
| 
 | |
|             if not self.ipg_bucket_has_moe_params:
 | |
|                 tensor.div_(dist.get_world_size(group=self.dp_process_group))
 | |
| 
 | |
|             async_handles = []
 | |
|             for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
 | |
|                 grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
 | |
|                 # if dist.get_rank() == 0:
 | |
|                 #     print(f"Rank {dist.get_rank()} rank offet id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
 | |
|                 # dist.barrier()
 | |
|                 # dist.barrier()
 | |
|                 dst_rank = _get_global_rank(real_dp_process_group[i], dst)
 | |
|                 async_handle = dist.reduce(grad_slice,
 | |
|                                            dst=dst_rank,
 | |
|                                            group=real_dp_process_group[i],
 | |
|                                            async_op=True)
 | |
|                 async_handles.append(async_handle)
 | |
| 
 | |
|             for handle in async_handles:
 | |
|                 handle.wait()
 | |
| 
 | |
|     ##############################################################################
 | |
|     ############################# CPU Offload Methods#############################
 | |
|     ##############################################################################
 | |
|     def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
 | |
|         current_offset = 0
 | |
| 
 | |
|         for i, tensor in enumerate(tensor_list):
 | |
|             param_id = self.get_param_id(tensor)
 | |
|             param_start_offset = 0
 | |
| 
 | |
|             num_elements = tensor.numel()
 | |
|             tensor_offset = 0
 | |
| 
 | |
|             # we need to offset to get to the right element
 | |
|             if i == 0 and first_offset > 0:
 | |
|                 tensor_offset = first_offset
 | |
|                 num_elements = num_elements - tensor_offset
 | |
|                 param_start_offset = first_offset
 | |
| 
 | |
|             # we dont need all elements of the tensor
 | |
|             if num_elements > (partition_size - current_offset):
 | |
|                 num_elements = partition_size - current_offset
 | |
| 
 | |
|             self.grad_position[param_id] = [
 | |
|                 int(group_id),
 | |
|                 int(param_start_offset),
 | |
|                 int(current_offset),
 | |
|                 int(num_elements)
 | |
|             ]
 | |
|             current_offset += num_elements
 | |
| 
 | |
|     def update_overflow_tracker_for_param_grad(self, param):
 | |
|         if param.grad is not None and self._has_inf_or_nan(param.grad.data):
 | |
|             self.local_overflow = True
 | |
| 
 | |
|     def async_accumulate_grad_in_cpu_via_gpu(self, param):
 | |
|         param_id = self.get_param_id(param)
 | |
| 
 | |
|         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 | |
| 
 | |
|         # copy to a preexisiting buffer to avoid memory allocation penalty
 | |
|         dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
 | |
|             0,
 | |
|             0,
 | |
|             param.numel())
 | |
| 
 | |
|         # buffer for storing gradients for this parameter in CPU
 | |
|         def buffer_to_accumulate_to_in_cpu():
 | |
|             if not self.fp16_master_weights_and_gradients:
 | |
|                 return torch.zeros(param.numel(),
 | |
|                                    dtype=param.dtype,
 | |
|                                    device=self.device).pin_memory()
 | |
|             else:
 | |
|                 return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
 | |
|                     0,
 | |
|                     dest_offset,
 | |
|                     num_elements)
 | |
| 
 | |
|         # accumulate gradients into param.grad or parts of it that belongs to this parittion
 | |
|         def accumulate_gradients():
 | |
|             if not self.fp16_master_weights_and_gradients:
 | |
|                 dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
 | |
|                                   non_blocking=True)
 | |
|                 param.grad.data.view(-1).add_(dest_buffer)
 | |
|             else:
 | |
|                 dest_buffer.narrow(0,
 | |
|                                    source_offset,
 | |
|                                    num_elements).copy_(
 | |
|                     self.accumulated_grads_in_cpu[param_id].view(-1),
 | |
|                     non_blocking=True)
 | |
|                 param.grad.data.view(-1).narrow(
 | |
|                     0,
 | |
|                     source_offset,
 | |
|                     num_elements).add_(dest_buffer.narrow(0,
 | |
|                                                           source_offset,
 | |
|                                                           num_elements))
 | |
| 
 | |
|         # move accumulated gradients back to CPU
 | |
|         def copy_gradients_to_cpu():
 | |
|             if not self.fp16_master_weights_and_gradients:
 | |
|                 self.accumulated_grads_in_cpu[param_id].data.copy_(
 | |
|                     param.grad.data.view(-1),
 | |
|                     non_blocking=True)
 | |
|             else:
 | |
|                 self.accumulated_grads_in_cpu[param_id].data.copy_(
 | |
|                     param.grad.data.view(-1).narrow(0,
 | |
|                                                     source_offset,
 | |
|                                                     num_elements),
 | |
|                     non_blocking=True)
 | |
| 
 | |
|         if param_id not in self.accumulated_grads_in_cpu:
 | |
|             self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu(
 | |
|             )
 | |
| 
 | |
|         if self.micro_step_id > 0:
 | |
|             accumulate_gradients()
 | |
| 
 | |
|         # at the boundary we will send 32bit directly
 | |
|         if not self.is_gradient_accumulation_boundary:
 | |
|             copy_gradients_to_cpu()
 | |
| 
 | |
|     def set_norm_for_param_grad(self, param):
 | |
|         param_id = self.get_param_id(param)
 | |
|         accumulated_grad = self.accumulated_grads_in_cpu[
 | |
|             param_id] if self.gradient_accumulation_steps > 1 else param.grad
 | |
| 
 | |
|         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 | |
| 
 | |
|         start = source_offset
 | |
|         accumulated_grad = accumulated_grad.view(
 | |
|             -1).narrow(0, start, num_elements)
 | |
| 
 | |
|         self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
 | |
|             2)
 | |
| 
 | |
|     def set_norm_for_param_grad_in_gpu(self, param):
 | |
|         param_id = self.get_param_id(param)
 | |
|         accumulated_grad = param.grad
 | |
| 
 | |
|         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 | |
| 
 | |
|         start = source_offset
 | |
|         accumulated_grad = accumulated_grad.view(
 | |
|             -1).narrow(0, start, num_elements)
 | |
| 
 | |
|         self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
 | |
|             2)
 | |
| 
 | |
|     def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
 | |
|         param_id = self.get_param_id(param)
 | |
| 
 | |
|         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 | |
| 
 | |
|         dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
 | |
|             0,
 | |
|             dest_offset,
 | |
|             num_elements)
 | |
| 
 | |
|         src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
 | |
|         if not self.fp16_master_weights_and_gradients:
 | |
|             src_tensor = src_tensor.float()
 | |
| 
 | |
|         dest_tensor.copy_(src_tensor, non_blocking=True)
 | |
|         param.grad = None  # offload only
 | |
| 
 | |
|     def complete_grad_norm_calculation_for_cpu_offload(self, params):
 | |
|         total_norm = 0.0
 | |
|         norm_type = 2.0
 | |
|         for p in params:
 | |
|             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
 | |
|                 param_id = self.get_param_id(p)
 | |
|                 # as some model have trainable parameters but skipped in training,
 | |
|                 # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
 | |
|                 # so they have no norm_for_param_grads
 | |
|                 if param_id in self.norm_for_param_grads:
 | |
|                     param_norm = self.norm_for_param_grads[param_id]
 | |
|                     total_norm += param_norm.item() ** 2
 | |
|                 else:
 | |
|                     # As unused parameters in modules may not be expected sometimes,
 | |
|                     # add an explicit error msg when it occurred and an option to
 | |
|                     # avoid the error
 | |
|                     assert self.ignore_unused_parameters, """
 | |
|                         This assert indicates that your module has parameters that
 | |
|                         were not used in producing loss.
 | |
|                         You can avoid this assert by
 | |
|                         (1) enable ignore_unused_parameters option in zero_optimization config;
 | |
|                         (2) making sure all trainable parameters and `forward` function
 | |
|                             outputs participate in calculating loss.
 | |
|                     """
 | |
| 
 | |
|         # Sum across all model parallel GPUs.
 | |
|         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
 | |
| 
 | |
|         torch.distributed.all_reduce(total_norm_cuda,
 | |
|                                      op=torch.distributed.ReduceOp.SUM,
 | |
|                                      group=self.dp_process_group)
 | |
| 
 | |
|         self._model_parallel_all_reduce(tensor=total_norm_cuda,
 | |
|                                         op=torch.distributed.ReduceOp.SUM)
 | |
| 
 | |
|         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
 | |
| 
 | |
|         if total_norm == float(
 | |
|                 'inf') or total_norm == -float('inf') or total_norm != total_norm:
 | |
|             total_norm = -1
 | |
| 
 | |
|         return total_norm
 | |
| 
 | |
|     ############################################################################################
 | |
| 
 | |
|     def copy_grads_in_partition(self, param):
 | |
|         if self.cpu_offload:
 | |
| 
 | |
|             if self.gradient_accumulation_steps > 1:
 | |
|                 self.async_accumulate_grad_in_cpu_via_gpu(param)
 | |
| 
 | |
|             if self.is_gradient_accumulation_boundary:
 | |
|                 self.set_norm_for_param_grad_in_gpu(param)
 | |
| 
 | |
|                 self.update_overflow_tracker_for_param_grad(param)
 | |
| 
 | |
|                 self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
 | |
| 
 | |
|             return
 | |
|         # print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
 | |
|         if self.grads_in_partition is None:
 | |
|             self.grads_in_partition_offset = 0
 | |
|             total_size = 0
 | |
|             for group in self.params_in_partition:
 | |
|                 for param_in_partition in group:
 | |
|                     total_size += param_in_partition.numel()
 | |
| 
 | |
|             if self.verbose:
 | |
|                 report_memory_usage(
 | |
|                     f"before copying {total_size} gradients into partition")
 | |
|             self.grads_in_partition = torch.empty(int(total_size),
 | |
|                                                   dtype=self.dtype,
 | |
|                                                   device=torch.cuda.current_device())
 | |
| 
 | |
|             if self.verbose:
 | |
|                 report_memory_usage(
 | |
|                     f"after copying {total_size} gradients into partition")
 | |
| 
 | |
|         # The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
 | |
|         new_grad_tensor = self.grads_in_partition.view(-1).narrow(
 | |
|             0,
 | |
|             self.grads_in_partition_offset,
 | |
|             param.numel())
 | |
|         new_grad_tensor.copy_(param.grad.view(-1))
 | |
|         param.grad.data = new_grad_tensor.data.view_as(param.grad)
 | |
|         # print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
 | |
|         self.grads_in_partition_offset += param.numel()
 | |
| 
 | |
|     def reduce_ipg_grads(self):
 | |
|         if self.contiguous_gradients:
 | |
|             if self.extra_large_param_to_reduce is not None:
 | |
|                 assert len(
 | |
|                     self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
 | |
|                 _, _, param_id = self.params_in_ipg_bucket[0]
 | |
|                 assert self.get_param_id(
 | |
|                     self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
 | |
|                 self.average_tensor(
 | |
|                     self.extra_large_param_to_reduce.grad.view(-1))
 | |
|                 self.extra_large_param_to_reduce = None
 | |
|             else:
 | |
|                 self.average_tensor(self.ipg_buffer[self.ipg_index])
 | |
|         else:
 | |
|             self.buffered_reduce_fallback(
 | |
|                 None,
 | |
|                 self.grads_in_ipg_bucket,
 | |
|                 elements_per_buffer=self.elements_in_ipg_bucket)
 | |
| 
 | |
|         if self.overlap_comm:
 | |
|             stream = self.reduction_stream
 | |
|         elif self.cpu_offload:
 | |
|             # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
 | |
|             #            torch.cuda.synchronize()
 | |
|             #            stream = self.copy_grad_stream
 | |
|             stream = torch.cuda.current_stream()
 | |
|         else:
 | |
|             stream = torch.cuda.current_stream()
 | |
| 
 | |
|         with torch.cuda.stream(stream):
 | |
|             for _, param, param_id in self.params_in_ipg_bucket:
 | |
| 
 | |
|                 assert self.params_already_reduced[param_id] == False, \
 | |
|                     f"The parameter {param_id} has already been reduced. \
 | |
|                     Gradient computed twice for this partition. \
 | |
|                     Multiple gradient reduction is currently not supported"
 | |
| 
 | |
|                 self.params_already_reduced[param_id] = True
 | |
| 
 | |
|                 if self.partition_gradients:
 | |
|                     if not self.is_param_in_current_partition[param_id]:
 | |
|                         if self.overlap_comm and self.contiguous_gradients is False:
 | |
|                             # Clear grads of other partitions during the next reduction
 | |
|                             # to avoid clearing them before the reduction is complete.
 | |
|                             if self.previous_reduced_grads is None:
 | |
|                                 self.previous_reduced_grads = []
 | |
|                             self.previous_reduced_grads.append(param)
 | |
|                         else:
 | |
|                             param.grad = None  # only if self.partition_gradients
 | |
|                     elif self.contiguous_gradients:
 | |
|                         self.copy_grads_in_partition(param)
 | |
| 
 | |
|         self.grads_in_ipg_bucket = []
 | |
|         self.params_in_ipg_bucket = []
 | |
|         self.ipg_bucket_has_moe_params = False
 | |
|         self.elements_in_ipg_bucket = 0
 | |
|         #####################################################################
 | |
| 
 | |
|     def reduce_ready_partitions_and_remove_grads(self, param, i):
 | |
|         if self.partition_gradients or self.is_gradient_accumulation_boundary:
 | |
|             self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
 | |
| 
 | |
|     def zero_reduced_gradients(self, partition_id, i):
 | |
|         def are_all_related_partitions_reduced(params_id):
 | |
|             for partition_id in self.param_to_partition_ids[i][params_id]:
 | |
|                 if not self.is_partition_reduced[i][partition_id]:
 | |
|                     return False
 | |
|             return True
 | |
| 
 | |
|         for params_id in self.is_grad_computed[i][partition_id]:
 | |
|             if are_all_related_partitions_reduced(params_id):
 | |
|                 self.param_dict[params_id].grad = None  # dead code
 | |
| 
 | |
|     def flatten_and_print(self, message, tensors, start=0, n=5):
 | |
|         flatten_tensor = self.flatten(tensors)
 | |
| 
 | |
|         def print_func():
 | |
|             print(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
 | |
| 
 | |
|         self.sequential_execution(print_func, message)
 | |
| 
 | |
|     def get_grads_to_reduce(self, i, partition_id):
 | |
|         def get_reducable_portion(key):
 | |
|             grad = self.param_dict[key].grad
 | |
|             total_elements = grad.numel()
 | |
|             start = self.grad_start_offset[i][partition_id][key]
 | |
|             num_elements = min(
 | |
|                 total_elements - start,
 | |
|                 self.partition_size[i] -
 | |
|                 self.grad_partition_insertion_offset[i][partition_id][key])
 | |
|             if not pg_correctness_test:
 | |
|                 if num_elements == total_elements:
 | |
|                     return grad
 | |
|                 else:
 | |
|                     return grad.contiguous().view(-1).narrow(0,
 | |
|                                                              int(start),
 | |
|                                                              int(num_elements))
 | |
|             else:
 | |
|                 if num_elements == total_elements:
 | |
|                     return grad.clone()
 | |
|                 else:
 | |
|                     return grad.clone().contiguous().view(-1).narrow(
 | |
|                         0,
 | |
|                         int(start),
 | |
|                         int(num_elements))
 | |
| 
 | |
|         grads_to_reduce = []
 | |
|         for key in self.is_grad_computed[i][partition_id]:
 | |
|             grad = get_reducable_portion(key)
 | |
|             grads_to_reduce.append(grad)
 | |
|         return grads_to_reduce
 | |
| 
 | |
|     def sequential_execution(self, function, message, group=None):
 | |
|         if group is None:
 | |
|             group = self.dp_process_group
 | |
|         if dist.get_rank(group=group) == 0:
 | |
|             print(message)
 | |
|         for id in range(dist.get_world_size(group=group)):
 | |
|             if id == dist.get_rank(group=group):
 | |
|                 function()
 | |
|             dist.barrier(group=group)
 | |
| 
 | |
|     def set_none_gradients_to_zero(self, i, partition_id):
 | |
|         for param_id in self.is_grad_computed[i][partition_id]:
 | |
|             param = self.param_dict[param_id]
 | |
|             if param.grad is None:
 | |
|                 param.grad = torch.zero_like(param)
 | |
| 
 | |
|     ######################Reduction Related Methods##############################
 | |
| 
 | |
|     def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
 | |
|         rank = None
 | |
|         tensor = self.flatten(bucket)
 | |
| 
 | |
|         tensor_to_allreduce = tensor
 | |
| 
 | |
|         if pg_correctness_test:
 | |
|             allreduce_always_fp32 = True
 | |
| 
 | |
|         if allreduce_always_fp32:
 | |
|             tensor_to_allreduce = tensor.float()
 | |
| 
 | |
|         tensor_to_allreduce.div_(
 | |
|             dist.get_world_size(group=self.dp_process_group))
 | |
| 
 | |
|         if rank is None:
 | |
|             #    "All Reducing"
 | |
|             dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
 | |
|         else:
 | |
|             global_rank = _get_global_rank(self.dp_process_group, rank)
 | |
|             dist.reduce(tensor_to_allreduce, global_rank,
 | |
|                         group=self.dp_process_group)
 | |
| 
 | |
|         if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
 | |
|             if rank is None or rank == dist.get_rank(group=self.dp_process_group):
 | |
|                 tensor.copy_(tensor_to_allreduce)
 | |
| 
 | |
|         return tensor
 | |
| 
 | |
|     def _clear_previous_reduced_grads(self):
 | |
|         if self.previous_reduced_grads is not None:
 | |
|             for param in self.previous_reduced_grads:
 | |
|                 param.grad = None  # overlap enabled
 | |
|             self.previous_reduced_grads = None
 | |
| 
 | |
|     # if rank is specified do a reduction instead of an allreduce
 | |
|     def allreduce_and_copy(self, small_bucket, rank=None, log=None):
 | |
|         if self.overlap_comm:
 | |
|             torch.cuda.synchronize()
 | |
|             # It is safe to clear the previously reduced grads of other partitions
 | |
|             self._clear_previous_reduced_grads()
 | |
|             stream = self.reduction_stream
 | |
|         else:
 | |
|             stream = torch.cuda.current_stream()
 | |
| 
 | |
|         with torch.cuda.stream(stream):
 | |
|             allreduced = self.allreduce_bucket(
 | |
|                 small_bucket, rank=rank, log=log)
 | |
|             if rank is None or rank == dist.get_rank(group=self.dp_process_group):
 | |
|                 for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
 | |
|                     buf.copy_(synced)
 | |
| 
 | |
|     def allreduce_no_retain(self,
 | |
|                             bucket,
 | |
|                             numel_per_bucket=500000000,
 | |
|                             rank=None,
 | |
|                             log=None):
 | |
|         small_bucket = []
 | |
|         numel = 0
 | |
|         for tensor in bucket:
 | |
|             small_bucket.append(tensor)
 | |
|             numel = numel + tensor.numel()
 | |
|             if numel > numel_per_bucket:
 | |
|                 self.allreduce_and_copy(small_bucket, rank=rank, log=None)
 | |
|                 small_bucket = []
 | |
| 
 | |
|         if len(small_bucket) > 0:
 | |
|             self.allreduce_and_copy(small_bucket, rank=rank, log=log)
 | |
| 
 | |
|     # allows using reduction of gradients instead of using all_reduce
 | |
| 
 | |
|     def buffered_reduce_fallback(self,
 | |
|                                  rank,
 | |
|                                  grads,
 | |
|                                  elements_per_buffer=500000000,
 | |
|                                  log=None):
 | |
|         split_buckets = split_half_float_double(grads)
 | |
| 
 | |
|         for i, bucket in enumerate(split_buckets):
 | |
|             self.allreduce_no_retain(bucket,
 | |
|                                      numel_per_bucket=elements_per_buffer,
 | |
|                                      rank=rank,
 | |
|                                      log=log)
 | |
| 
 | |
|     #############################################################################
 | |
|     #############################################################################
 | |
|     #############################################################################
 | |
| 
 | |
|     # views the tensor as multiple partitions and returns
 | |
|     # those partitions
 | |
|     def get_data_parallel_partitions(self, tensor, group_id):
 | |
|         partitions = []
 | |
| 
 | |
|         dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
 | |
|         dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
 | |
| 
 | |
|         total_num_elements = tensor.numel()
 | |
| 
 | |
|         base_size = total_num_elements // dp
 | |
|         remaining = total_num_elements % dp
 | |
| 
 | |
|         start = 0
 | |
|         for id in range(dp):
 | |
|             partition_size = base_size
 | |
|             if id < remaining:
 | |
|                 partition_size = partition_size + 1
 | |
|             partitions.append(tensor.narrow(0, start, partition_size))
 | |
|             start = start + partition_size
 | |
|         return partitions
 | |
| 
 | |
|     def get_partition_info(self, tensor_list, partition_size, partition_id):
 | |
|         params_in_partition = []
 | |
|         params_not_in_partition = []
 | |
| 
 | |
|         start_index = partition_size * partition_id
 | |
|         end_index = partition_size * (partition_id + 1)
 | |
| 
 | |
|         current_index = 0
 | |
|         first_offset = 0
 | |
| 
 | |
|         for tensor in tensor_list:
 | |
| 
 | |
|             tensor_size = tensor.numel()
 | |
| 
 | |
|             if (current_index >= start_index and current_index < end_index):
 | |
|                 params_in_partition.append(tensor)
 | |
| 
 | |
|             elif start_index > current_index and start_index < (current_index +
 | |
|                                                                 tensor_size):
 | |
|                 params_in_partition.append(tensor)
 | |
| 
 | |
|                 assert (
 | |
|                     first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
 | |
|                 first_offset = start_index - current_index
 | |
| 
 | |
|             else:
 | |
|                 params_not_in_partition.append(tensor)
 | |
| 
 | |
|             current_index = current_index + tensor_size
 | |
| 
 | |
|         return params_in_partition, params_not_in_partition, first_offset
 | |
| 
 | |
|     def zero_grad(self, set_grads_to_None=True):
 | |
|         """
 | |
|         Zero FP16 parameter grads.
 | |
|         """
 | |
|         # FP32 grad should never exist.
 | |
|         # For speed, set model fp16 grad to None by default
 | |
|         for group in self.fp16_groups:
 | |
|             for p in group:
 | |
|                 if set_grads_to_None:
 | |
|                     p.grad = None  # epilogue and in step
 | |
|                 else:
 | |
|                     if p.grad is not None:
 | |
|                         p.grad.detach_()
 | |
|                         p.grad.zero_()
 | |
| 
 | |
|     def _model_parallel_all_reduce(self, tensor, op):
 | |
|         """ Perform all reduce within model parallel group, if any.
 | |
|         """
 | |
|         if self.model_parallel_group is None:
 | |
|             pass
 | |
|         else:
 | |
|             torch.distributed.all_reduce(tensor=tensor,
 | |
|                                          op=op,
 | |
|                                          group=self.model_parallel_group)
 | |
| 
 | |
|     def clip_grad_norm(self, *args, **kwargs):
 | |
|         # dummy function to retain the same function interface
 | |
|         # as ColossalaiOptimizer for compatibility
 | |
|         pass
 | |
| 
 | |
|     def get_grad_norm_direct(self, gradients, params, norm_type=2):
 | |
|         """Clips gradient norm of an iterable of parameters.
 | |
| 
 | |
|         This is adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` and
 | |
|         added functionality to handle model parallel parameters. Note that
 | |
|         the gradients are modified in place.
 | |
| 
 | |
|         Arguments:
 | |
|             parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
 | |
|                 single Tensor that will have gradients normalized
 | |
|             max_norm (float or int): max norm of the gradients
 | |
|             norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
 | |
|                 infinity norm.
 | |
| 
 | |
|         Returns:
 | |
|             Total norm of the parameters (viewed as a single vector).
 | |
|         """
 | |
|         norm_type = float(norm_type)
 | |
|         if norm_type == inf:
 | |
|             total_norm = max(g.data.abs().max() for g in gradients)
 | |
|             total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
 | |
|             torch.distributed.all_reduce(total_norm_cuda,
 | |
|                                          op=torch.distributed.ReduceOp.MAX,
 | |
|                                          group=self.dp_process_group)
 | |
| 
 | |
|             # Take max across all GPUs.
 | |
|             self._model_parallel_all_reduce(tensor=total_norm_cuda,
 | |
|                                             op=torch.distributed.ReduceOp.MAX)
 | |
|             total_norm = total_norm_cuda[0].item()
 | |
|         else:
 | |
|             total_norm = 0.0
 | |
|             # if dist.get_rank() == 0:
 | |
|             #    print()(f"Total Norm begining {total_norm}")
 | |
|             for g, p in zip(gradients, params):
 | |
|                 if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
 | |
|                     param_norm = g.data.double().norm(2)
 | |
|                     total_norm += param_norm.item() ** 2
 | |
|             # Sum across all model parallel GPUs.
 | |
|             total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
 | |
| 
 | |
|             torch.distributed.all_reduce(total_norm_cuda,
 | |
|                                          op=torch.distributed.ReduceOp.SUM,
 | |
|                                          group=self.dp_process_group)
 | |
| 
 | |
|             self._model_parallel_all_reduce(tensor=total_norm_cuda,
 | |
|                                             op=torch.distributed.ReduceOp.SUM)
 | |
| 
 | |
|             total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
 | |
| 
 | |
|         if total_norm == float(
 | |
|                 'inf') or total_norm == -float('inf') or total_norm != total_norm:
 | |
|             total_norm = -1
 | |
| 
 | |
|         return total_norm
 | |
| 
 | |
|     # creates a flat fused tensor from the tensor list starting at the first_offset
 | |
|     # in the first tensor of the list. If there are not enough elements in the tensor
 | |
|     # list then the flat tensor will be padded with zeros
 | |
|     def get_flat_partition(self,
 | |
|                            tensor_list,
 | |
|                            first_offset,
 | |
|                            partition_size,
 | |
|                            dtype,
 | |
|                            device,
 | |
|                            return_tensor_list=False):
 | |
|         flat_tensor_list = []
 | |
|         current_size = 0
 | |
|         for i, tensor in enumerate(tensor_list):
 | |
|             if tensor.grad is None:
 | |
|                 tensor.grad = torch.zeros_like(tensor)
 | |
| 
 | |
|             tensor = tensor.grad
 | |
|             num_elements = tensor.numel()
 | |
|             tensor_offset = 0
 | |
| 
 | |
|             # we need to offset to get to the right element
 | |
|             if i == 0 and first_offset > 0:
 | |
|                 tensor_offset = first_offset
 | |
|                 num_elements = num_elements - tensor_offset
 | |
| 
 | |
|             # we dont need all elements of the tensor
 | |
|             if num_elements > (partition_size - current_size):
 | |
|                 num_elements = partition_size - current_size
 | |
| 
 | |
|             # we need a narrow view of the tensor based on the tensor offset and number of elements that
 | |
|             # we need from this tensor
 | |
|             if tensor_offset > 0 or num_elements < tensor.numel():
 | |
|                 flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
 | |
|                     0,
 | |
|                     int(tensor_offset),
 | |
|                     int(num_elements)))
 | |
|             else:
 | |
|                 flat_tensor_list.append(tensor)
 | |
| 
 | |
|             current_size = current_size + num_elements
 | |
| 
 | |
|         # this means its the last partition and does not align with the dp boundary. We need to pad before flattening
 | |
|         if current_size < partition_size:
 | |
|             flat_tensor_list.append(
 | |
|                 torch.zeros(int(partition_size - current_size),
 | |
|                             dtype=dtype,
 | |
|                             device=device))
 | |
| 
 | |
|         if return_tensor_list:
 | |
|             return flat_tensor_list
 | |
| 
 | |
|         return self.flatten(flat_tensor_list)
 | |
| 
 | |
|     def free_grad_in_param_list(self, param_list):
 | |
|         for p in param_list:
 | |
|             p.grad = None  # in step
 | |
| 
 | |
|     def reset_cpu_buffers(self):
 | |
|         self.norm_for_param_grads = {}
 | |
|         self.local_overflow = False
 | |
| 
 | |
|     def log_timers(self, timer_names):
 | |
|         if self.timers is None:
 | |
|             return
 | |
| 
 | |
|         self.timers.log(names=list(timer_names))
 | |
| 
 | |
|     def start_timers(self, timer_names):
 | |
|         if self.timers is None:
 | |
|             return
 | |
| 
 | |
|         for name in timer_names:
 | |
|             self.timers(name).start()
 | |
| 
 | |
|     def stop_timers(self, timer_names):
 | |
|         if self.timers is None:
 | |
|             return
 | |
| 
 | |
|         for name in timer_names:
 | |
|             self.timers(name).stop()
 | |
| 
 | |
|     def step(self, closure=None):
 | |
|         """
 | |
|         Not supporting closure.
 | |
|         """
 | |
|         self.micro_step_id = -1
 | |
| 
 | |
|         if self.verbose:
 | |
|             report_memory_usage(f"In step before checking overflow")
 | |
| 
 | |
|         # First compute norm for all group so we know if there is overflow
 | |
|         self.check_overflow(self.partition_gradients)
 | |
| 
 | |
|         OPTIMIZER_ALLGATHER = 'optimizer_allgather'
 | |
|         OPTIMIZER_GRADIENTS = 'optimizer_gradients'
 | |
|         OPTIMIZER_STEP = 'optimizer_step'
 | |
|         timer_names = [OPTIMIZER_ALLGATHER,
 | |
|                        OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]
 | |
| 
 | |
|         prev_scale = self.loss_scale
 | |
|         self._update_scale(self.overflow)
 | |
|         if self.overflow:
 | |
|             if self.verbose:
 | |
|                 report_memory_usage('After overflow before clearing gradients')
 | |
|             self.zero_grad()
 | |
|             if self.cpu_offload:
 | |
|                 self.reset_cpu_buffers()
 | |
|             else:
 | |
|                 self.averaged_gradients = {}
 | |
| 
 | |
|             if self.verbose:
 | |
|                 report_memory_usage('After overflow after clearing gradients')
 | |
| 
 | |
|             print(
 | |
|                 "[deepspeed] fp16 dynamic loss scale overflow! Rank {} Skipping step. Attempted loss scale: {}, "
 | |
|                 "reducing to {}".format(dist.get_rank(),
 | |
|                                         prev_scale,
 | |
|                                         self.loss_scale))
 | |
|             self.start_timers(timer_names)
 | |
|             self.stop_timers(timer_names)
 | |
|             return
 | |
| 
 | |
|         self.start_timers([OPTIMIZER_GRADIENTS])
 | |
|         norm_groups = []
 | |
|         single_partition_grad_groups = []
 | |
|         skip = False
 | |
|         for i, group in enumerate(self.fp16_groups):
 | |
|             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
 | |
|             if self.cpu_offload:
 | |
|                 norm_groups.append(
 | |
|                     self.complete_grad_norm_calculation_for_cpu_offload(
 | |
|                         self.params_in_partition[i]))
 | |
|                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
 | |
|             else:
 | |
|                 norm_groups.append(
 | |
|                     self.get_grad_norm_direct(self.averaged_gradients[i],
 | |
|                                               self.params_in_partition[i]))
 | |
| 
 | |
|                 # free gradients for all the prameters that are not updated by this process
 | |
|                 self.free_grad_in_param_list(self.params_not_in_partition[i])
 | |
| 
 | |
|                 # create a flat gradients for parameters updated by this process
 | |
|                 # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
 | |
|                 if partition_id == dist.get_world_size(
 | |
|                         group=self.real_dp_process_group[i]) - 1:
 | |
|                     single_grad_partition = self.flatten_dense_tensors_aligned(
 | |
|                         self.averaged_gradients[i],
 | |
|                         int(self.partition_size[i])).to(
 | |
|                         self.single_partition_of_fp32_groups[i].dtype)
 | |
|                 else:
 | |
|                     single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
 | |
|                         self.single_partition_of_fp32_groups[i].dtype)
 | |
|                 assert single_grad_partition.numel() == self.partition_size[i], \
 | |
|                     "averaged gradients have different number of elements that partition size {} {} {} {}".format(
 | |
|                         single_grad_partition.numel(), self.partition_size[i], i, partition_id)
 | |
| 
 | |
|                 self.single_partition_of_fp32_groups[i].grad = single_grad_partition
 | |
|                 # release all the gradient since we have already created a necessary copy in dp_grad_partition
 | |
|                 self.free_grad_in_param_list(self.params_in_partition[i])
 | |
| 
 | |
|                 self.averaged_gradients[i] = None
 | |
| 
 | |
|             single_partition_grad_groups.append(single_grad_partition)
 | |
| 
 | |
|         if self.has_moe_layers:
 | |
|             self._average_expert_grad_norms(norm_groups)
 | |
| 
 | |
|         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
 | |
|         self.stop_timers([OPTIMIZER_GRADIENTS])
 | |
| 
 | |
|         self.start_timers([OPTIMIZER_STEP])
 | |
|         if self.deepspeed_adam_offload:
 | |
|             from deepspeed.ops.adam import DeepSpeedCPUAdam
 | |
|             if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
 | |
|                 fp16_param_groups = [
 | |
|                     fp16_partitions[partition_id]
 | |
|                     for fp16_partitions in self.parallel_partitioned_fp16_groups
 | |
|                 ]
 | |
|                 self.optimizer.step(fp16_param_groups=fp16_param_groups)
 | |
|             else:
 | |
|                 self.optimizer.step()
 | |
|                 for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
 | |
|                                                            self.single_partition_of_fp32_groups):
 | |
|                     fp16_partitions[partition_id].data.copy_(
 | |
|                         fp32_partition.data)
 | |
|         else:
 | |
|             self.optimizer.step()
 | |
| 
 | |
|             # get rid of the fp32 gradients. Not needed anymore
 | |
|             if not self.cpu_offload:
 | |
|                 for group in self.single_partition_of_fp32_groups:
 | |
|                     group.grad = None  # in step
 | |
| 
 | |
|             for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
 | |
|                                                        self.single_partition_of_fp32_groups):
 | |
|                 fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 | |
| 
 | |
|         self.stop_timers([OPTIMIZER_STEP])
 | |
| 
 | |
|         if self.cpu_offload:
 | |
|             self.reset_cpu_buffers()
 | |
| 
 | |
|         self.start_timers([OPTIMIZER_ALLGATHER])
 | |
|         # gather the updated weights from everyone
 | |
|         for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
 | |
| 
 | |
|             # Sequential AllGather Best of both worlds
 | |
|             dp_world_size = dist.get_world_size(
 | |
|                 group=self.real_dp_process_group[group_id])
 | |
|             num_shards = max(
 | |
|                 1,
 | |
|                 partitioned_params[partition_id].numel() * dp_world_size //
 | |
|                 self.allgather_bucket_size)
 | |
| 
 | |
|             shard_size = partitioned_params[partition_id].numel() // num_shards
 | |
|             num_elements = shard_size
 | |
| 
 | |
|             assert shard_size * \
 | |
|                 num_shards <= partitioned_params[partition_id].numel()
 | |
| 
 | |
|             for shard_id in range(num_shards):
 | |
| 
 | |
|                 if shard_id == (num_shards - 1):
 | |
|                     num_elements = partitioned_params[partition_id].numel(
 | |
|                     ) - shard_id * shard_size
 | |
| 
 | |
|                 shard_list = []
 | |
|                 for dp_id in range(dp_world_size):
 | |
|                     curr_shard = partitioned_params[dp_id].narrow(
 | |
|                         0,
 | |
|                         shard_id * shard_size,
 | |
|                         num_elements).detach()
 | |
|                     shard_list.append(curr_shard)
 | |
| 
 | |
|                 dist.all_gather(shard_list,
 | |
|                                 shard_list[partition_id],
 | |
|                                 group=self.real_dp_process_group[group_id])
 | |
|         self.stop_timers([OPTIMIZER_ALLGATHER])
 | |
| 
 | |
|         # TODO: we probably don't need this? just to be safe
 | |
|         for i in range(len(norm_groups)):
 | |
|             self._update_model_fp16_weights(i)
 | |
| 
 | |
|         self.log_timers(timer_names)
 | |
|         if self.verbose:
 | |
|             report_memory_usage('After zero_optimizer step')
 | |
| 
 | |
|         return
 | |
| 
 | |
|     def _average_expert_grad_norms(self, norm_groups):
 | |
|         for i, norm in enumerate(norm_groups):
 | |
|             if self.is_moe_param_group[i]:
 | |
|                 scaled_norm = norm * 1.0 / float(
 | |
|                     dist.get_world_size(group=self.ep_process_group))
 | |
|                 scaled_norm_tensor = torch.tensor(scaled_norm,
 | |
|                                                   device='cuda',
 | |
|                                                   dtype=torch.float)
 | |
|                 dist.all_reduce(scaled_norm_tensor,
 | |
|                                 group=self.ep_process_group)
 | |
|                 norm_groups[i] = scaled_norm_tensor.item()
 | |
| 
 | |
|     def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
 | |
|         total_norm = 0.0
 | |
|         for norm in norm_groups:
 | |
|             total_norm += norm ** 2.0
 | |
|         total_norm = math.sqrt(total_norm)
 | |
| 
 | |
|         # compute combined scale factor for this group
 | |
|         combined_scale = self.loss_scale
 | |
|         if self.clip_grad > 0.:
 | |
|             # norm is in fact norm*scale
 | |
|             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
 | |
|             if clip > 1:
 | |
|                 combined_scale = clip * self.loss_scale
 | |
| 
 | |
|         for grad in grad_groups_flat:
 | |
|             if isinstance(grad, list):
 | |
|                 sub_partitions = grad
 | |
|                 for g in sub_partitions:
 | |
|                     g.data.mul_(1. / combined_scale)
 | |
|             else:
 | |
|                 grad.data.mul_(1. / combined_scale)
 | |
| 
 | |
|     def _check_overflow(self, partition_gradients=True):
 | |
|         self.overflow = self.has_overflow(partition_gradients)
 | |
| 
 | |
|     # `params` is a list / generator of torch.Variable
 | |
|     def has_overflow_serial(self, params, is_grad_list=False):
 | |
|         for p in params:
 | |
|             if p.grad is not None and self._has_inf_or_nan(p.grad.data):
 | |
|                 return True
 | |
| 
 | |
|         return False
 | |
| 
 | |
|     def has_overflow_partitioned_grads_serial(self):
 | |
|         for i in range(len(self.fp16_groups)):
 | |
|             for j, grad in enumerate(self.averaged_gradients[i]):
 | |
|                 if grad is not None and self._has_inf_or_nan(grad.data, j):
 | |
|                     return True
 | |
|         return False
 | |
| 
 | |
|     def has_overflow(self, partition_gradients=True):
 | |
|         if partition_gradients:
 | |
|             overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
 | |
|             )
 | |
|             overflow_gpu = torch.cuda.ByteTensor([overflow])
 | |
|             '''This will capture overflow across all data parallel and expert parallel process
 | |
|             Since expert parallel process are a subset of data parallel process'''
 | |
|             torch.distributed.all_reduce(overflow_gpu,
 | |
|                                          op=torch.distributed.ReduceOp.MAX,
 | |
|                                          group=self.dp_process_group)
 | |
| 
 | |
|         else:
 | |
|             params = []
 | |
|             for group in self.fp16_groups:
 | |
|                 for param in group:
 | |
|                     params.append(param)
 | |
| 
 | |
|             overflow = self.has_overflow_serial(
 | |
|                 params, is_grad_list=partition_gradients)
 | |
|             overflow_gpu = torch.cuda.ByteTensor([overflow])
 | |
| 
 | |
|         # Since each model parallel GPU carries only part of the model,
 | |
|         # make sure overflow flag is synced across all the model parallel GPUs
 | |
|         self._model_parallel_all_reduce(tensor=overflow_gpu,
 | |
|                                         op=torch.distributed.ReduceOp.MAX)
 | |
| 
 | |
|         overflow = overflow_gpu[0].item()
 | |
|         return bool(overflow)
 | |
| 
 | |
|     # `x` is a torch.Tensor
 | |
|     @staticmethod
 | |
|     def _has_inf_or_nan(x, j=None):
 | |
|         try:
 | |
|             # if x is half, the .float() incurs an additional deep copy, but it's necessary if
 | |
|             # Pytorch's .sum() creates a one-element tensor of the same type as x
 | |
|             # (which is true for some recent version of pytorch).
 | |
|             cpu_sum = float(x.float().sum())
 | |
|             # More efficient version that can be used if .sum() returns a Python scalar
 | |
|             # cpu_sum = float(x.sum())
 | |
|         except RuntimeError as instance:
 | |
|             # We want to check if inst is actually an overflow exception.
 | |
|             # RuntimeError could come from a different error.
 | |
|             # If so, we still want the exception to propagate.
 | |
|             if "value cannot be converted" not in instance.args[0]:
 | |
|                 raise
 | |
|             return True
 | |
|         else:
 | |
|             if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
 | |
|                 return True
 | |
|             return False
 | |
| 
 | |
|     def backward(self, loss, retain_graph=False):
 | |
|         """
 | |
|         :attr:`backward` performs the following steps:
 | |
| 
 | |
|         1. fp32_loss = loss.float()
 | |
|         2. scaled_loss = fp32_loss*loss_scale
 | |
|         3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
 | |
|         """
 | |
|         self.micro_step_id += 1
 | |
| 
 | |
|         if self.contiguous_gradients:
 | |
|             self.ipg_buffer = []
 | |
|             buf_0 = torch.empty(int(self.reduce_bucket_size),
 | |
|                                 dtype=self.dtype,
 | |
|                                 device=torch.cuda.current_device())
 | |
|             self.ipg_buffer.append(buf_0)
 | |
| 
 | |
|             # Use double buffers to avoid data access conflict when overlap_comm is enabled.
 | |
|             if self.overlap_comm:
 | |
|                 buf_1 = torch.empty(int(self.reduce_bucket_size),
 | |
|                                     dtype=self.dtype,
 | |
|                                     device=torch.cuda.current_device())
 | |
|                 self.ipg_buffer.append(buf_1)
 | |
|             self.ipg_index = 0
 | |
| 
 | |
|         self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
 | |
| 
 | |
|     def check_overflow(self, partition_gradients=True):
 | |
|         self._check_overflow(partition_gradients)
 | |
| 
 | |
|     def _update_scale(self, has_overflow=False):
 | |
|         self.loss_scaler.update_scale(has_overflow)
 | |
| 
 | |
|     # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
 | |
|     def _get_state(self):
 | |
|         return self.optimizer.state
 | |
| 
 | |
|     def _set_state(self, value):
 | |
|         self.optimizer.state = value
 | |
| 
 | |
|     state = property(_get_state, _set_state)
 | |
| 
 | |
|     # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
 | |
|     # (for example, to adjust the learning rate)
 | |
|     def _get_param_groups(self):
 | |
|         return self.optimizer.param_groups
 | |
| 
 | |
|     def _set_param_groups(self, value):
 | |
|         self.optimizer.param_groups = value
 | |
| 
 | |
|     param_groups = property(_get_param_groups, _set_param_groups)
 | |
| 
 | |
|     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
 | |
|     def _get_loss_scale(self):
 | |
|         return self.loss_scaler.loss_scale
 | |
| 
 | |
|     def _set_loss_scale(self, value):
 | |
|         self.loss_scaler.cur_scale = value
 | |
| 
 | |
|     loss_scale = property(_get_loss_scale, _set_loss_scale)
 | |
|     cur_scale = property(_get_loss_scale, _set_loss_scale)
 | |
| 
 | |
|     # Return group tensor after removing paddings that are added for alignment to DP world size.
 | |
|     # This method works on the assumption that each group contains a single flattened tensor.
 | |
|     def _get_groups_without_padding(self, groups_with_padding):
 | |
|         groups_without_padding = []
 | |
|         for i, group in enumerate(groups_with_padding):
 | |
|             lean_length = group.numel() - self.groups_padding[i]
 | |
|             groups_without_padding.append(group[:lean_length])
 | |
| 
 | |
|         return groups_without_padding
 | |
| 
 | |
|     # Return optimizer state after removing paddings that are added for alignment.
 | |
|     def _get_state_without_padding(self, state_with_padding, padding):
 | |
|         lean_state = {}
 | |
|         for key, value in state_with_padding.items():
 | |
|             if torch.is_tensor(value):
 | |
|                 lean_length = value.numel() - padding
 | |
|                 lean_state[key] = value[:lean_length]
 | |
|             else:
 | |
|                 lean_state[key] = value
 | |
| 
 | |
|         return lean_state
 | |
| 
 | |
|     # Return base optimizer states.
 | |
|     # This method assumes that each param group contains a single flattened tensor.
 | |
|     def _get_base_optimizer_state(self):
 | |
|         optimizer_groups_state = []
 | |
|         for i, group in enumerate(self.optimizer.param_groups):
 | |
|             p = group['params'][0]
 | |
|             lean_optimizer_state = self._get_state_without_padding(
 | |
|                 self.optimizer.state[p],
 | |
|                 self.groups_padding[i])
 | |
|             optimizer_groups_state.append(lean_optimizer_state)
 | |
| 
 | |
|         return optimizer_groups_state
 | |
| 
 | |
|     def state_dict(self):
 | |
|         """
 | |
|         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
 | |
|         This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
 | |
|         of the contained Pytorch optimizer.
 | |
| 
 | |
|         Example::
 | |
| 
 | |
|             checkpoint = {}
 | |
|             checkpoint['model'] = model.state_dict()
 | |
|             checkpoint['optimizer'] = optimizer.state_dict()
 | |
|             torch.save(checkpoint, "saved.pth")
 | |
|         """
 | |
|         state_dict = {}
 | |
|         state_dict['loss_scaler'] = self.loss_scaler
 | |
|         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
 | |
|         state_dict['overflow'] = self.overflow
 | |
|         state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
 | |
| 
 | |
|         state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
 | |
|         state_dict['partition_count'] = self.partition_count
 | |
| 
 | |
|         state_dict['ds_version'] = version
 | |
| 
 | |
|         # Remove paddings for DP alignment to enable loading for other alignment values
 | |
|         fp32_groups_without_padding = self._get_groups_without_padding(
 | |
|             self.single_partition_of_fp32_groups)
 | |
|         state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
 | |
| 
 | |
|         #        if self.cpu_offload:
 | |
|         #            state_dict_tmp = async_copy_to(state_dict,
 | |
|         #                                           'cpu',
 | |
|         #                                           torch.cuda.current_stream())
 | |
|         #            state_dict = state_dict_tmp
 | |
| 
 | |
|         return state_dict
 | |
| 
 | |
|     # Restore base optimizer fp32 weights from checkpoint by:
 | |
|     # 1) Merging fp32 weights from checkpoints of all partitions
 | |
|     # 2) Extracting fp32 weights for current partition from merged weights
 | |
|     # 3) Using extracted weights to update base optimizer weights directly.
 | |
|     def _restore_from_fp32_weights(self, all_state_dict):
 | |
|         merged_single_partition_of_fp32_groups = []
 | |
|         for i in range(len(self.single_partition_of_fp32_groups)):
 | |
|             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
 | |
|             merged_partitions = [
 | |
|                 sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
 | |
|             ]
 | |
|             flat_merged_partitions = self.flatten_dense_tensors_aligned(
 | |
|                 merged_partitions,
 | |
|                 self.nccl_start_alignment_factor *
 | |
|                 dist.get_world_size(group=self.real_dp_process_group[i]))
 | |
|             dp_partitions = self.get_data_parallel_partitions(
 | |
|                 flat_merged_partitions, i)
 | |
|             merged_single_partition_of_fp32_groups.append(
 | |
|                 dp_partitions[partition_id])
 | |
| 
 | |
|         for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
 | |
|             current.data.copy_(saved.data)
 | |
| 
 | |
|     # Restore base optimizer fp32 weights from ZeRO fp16 weights
 | |
|     def _restore_from_fp16_weights(self):
 | |
|         for group_id, fp16_partitions, fp32_partition in enumerate(
 | |
|                 zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups)):
 | |
|             partition_id = dist.get_rank(
 | |
|                 group=self.real_dp_process_group[group_id])
 | |
|             fp32_partition.data.copy_(fp16_partitions[partition_id].data)
 | |
| 
 | |
|     # Refresh the fp32 master params from the fp16 copies.
 | |
|     def refresh_fp32_params(self):
 | |
|         self._restore_from_fp16_weights()
 | |
| 
 | |
|     # Extract optimizer state for current partition from merged states of all partitions
 | |
|     def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):
 | |
|         partition_id = dist.get_rank(
 | |
|             group=self.real_dp_process_group[group_id])
 | |
|         alignment = dist.get_world_size(
 | |
|             group=self.real_dp_process_group[group_id])
 | |
|         if torch.is_tensor(all_partition_states[0]):
 | |
|             flat_merged_partitions = self.flatten_dense_tensors_aligned(
 | |
|                 all_partition_states,
 | |
|                 alignment)
 | |
|             dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions,
 | |
|                                                               group_id)
 | |
|             return dp_partitions[partition_id]
 | |
|         else:
 | |
|             # Assume non-tensor states are not partitioned and equal across ranks, so return first one
 | |
|             return all_partition_states[0]
 | |
| 
 | |
|     # Restore base optimizer state from checkpoint by
 | |
|     # 1) Merging optimizer state from checkpoints of all partitions
 | |
|     # 2) Extracting optimizer state for current partition from the merged state
 | |
|     # 3) Using the extracted value to directly update the base optimizer.
 | |
|     def _restore_base_optimizer_state(self, all_state_dict):
 | |
|         base_optimizer_group_states = []
 | |
|         for i in range(len(self.optimizer.param_groups)):
 | |
|             partition_states = {}
 | |
|             all_partition_group_states = [
 | |
|                 sd['base_optimizer_state'][i] for sd in all_state_dict
 | |
|             ]
 | |
|             for key in all_partition_group_states[0].keys():
 | |
|                 all_partition_states = [
 | |
|                     all_states[key] for all_states in all_partition_group_states
 | |
|                 ]
 | |
|                 partition_states[key] = self._partition_base_optimizer_state(
 | |
|                     key,
 | |
|                     all_partition_states,
 | |
|                     i)
 | |
|             base_optimizer_group_states.append(partition_states)
 | |
| 
 | |
|         for i, group in enumerate(self.optimizer.param_groups):
 | |
|             p = group['params'][0]
 | |
|             for key, saved in base_optimizer_group_states[i].items():
 | |
|                 if torch.is_tensor(self.optimizer.state[p][key]):
 | |
|                     self.optimizer.state[p][key].data.copy_(saved.data)
 | |
|                 else:
 | |
|                     self.optimizer.state[p][key] = saved
 | |
| 
 | |
|     def load_state_dict(self,
 | |
|                         state_dict_list,
 | |
|                         load_optimizer_states=True,
 | |
|                         load_from_fp32_weights=False):
 | |
|         r"""Loading ZeRO checkpoint
 | |
| 
 | |
|         Arguments:
 | |
|             state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
 | |
|                 Note that the number of saved partitions may differ from number of loading partitions to support
 | |
|                 changing GPU count, specifically DP world size, between saving and loading checkpoints.
 | |
|             load_optimizer_states: Boolean indicating whether or not to load base optimizer states
 | |
|             load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
 | |
|             copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
 | |
|         """
 | |
|         """
 | |
|         Loads a state_dict created by an earlier call to state_dict().
 | |
|         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
 | |
|         whose parameters in turn came from ``model``, it is expected that the user
 | |
|         will call ``model.load_state_dict()`` before
 | |
|         ``fp16_optimizer_instance.load_state_dict()`` is called.
 | |
|         
 | |
|         Example::
 | |
| 
 | |
|             model = torch.nn.Linear(D_in, D_out).cuda().half()
 | |
|             optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 | |
|             optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
 | |
|             ...
 | |
|             checkpoint = torch.load("saved.pth")
 | |
|             model.load_state_dict(checkpoint['model'])
 | |
|             optimizer.load_state_dict(checkpoint['optimizer'])
 | |
|         """
 | |
|         # I think it should actually be ok to reload the optimizer before the model.
 | |
|         self.loss_scaler = state_dict_list[0]['loss_scaler']
 | |
|         self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
 | |
|         self.overflow = state_dict_list[0]['overflow']
 | |
| 
 | |
|         # zero stage 1 mode
 | |
|         if not self.partition_gradients:
 | |
|             required_version = pkg_version.parse("0.3.17")
 | |
|             ckpt_version = state_dict_list[0].get("ds_version", False)
 | |
|             error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
 | |
|                         "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
 | |
|                         "please set 'legacy_stage1': true in your zero config json. This old version of " \
 | |
|                         "stage 1 will be removed in v0.4.0."
 | |
| 
 | |
|             assert ckpt_version, f"Empty ds_version! {error_str}"
 | |
|             assert required_version <= pkg_version.parse(
 | |
|                 ckpt_version), f"Old version: {ckpt_version} {error_str}"
 | |
| 
 | |
|         if load_optimizer_states:
 | |
|             self._restore_base_optimizer_state(state_dict_list)
 | |
| 
 | |
|         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
 | |
|         # The optimizer's hyperparameters and internal buffers are also up to date.
 | |
|         # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
 | |
|         # out of date.  There are two options.
 | |
|         # 1:  Refresh the master params from the model's fp16 params.
 | |
|         # This requires less storage but incurs precision loss.
 | |
|         # 2:  Save and restore the fp32 master copies separately.
 | |
|         # We choose option 1 if changing DP degree and option 2 otherwise.
 | |
|         #
 | |
|         # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
 | |
|         # of their associated parameters, because it's possible those buffers might not exist yet in
 | |
|         # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
 | |
|         # constructed in the same way as the one whose state_dict we are loading, the same master params
 | |
|         # are guaranteed to exist, so we can just copy_() from the saved master params.
 | |
| 
 | |
|         if load_from_fp32_weights:
 | |
|             self._restore_from_fp32_weights(state_dict_list)
 | |
|         else:
 | |
|             self._restore_from_fp16_weights()
 | |
| 
 | |
|     def allreduce_gradients(self):
 | |
|         self.overlapping_partition_gradients_reduce_epilogue()
 | |
| 
 | |
| 
 | |
| def _handle_overflow(cpu_sum, x, i):
 | |
|     import math
 | |
|     rank = torch.distributed.get_rank()
 | |
|     if rank == 0:
 | |
|         t_i = -1
 | |
|         for v_i, v in enumerate(x.data.contiguous().view(-1)):
 | |
|             if not math.isfinite(float(v)):
 | |
|                 t_i = v_i
 | |
|                 break
 | |
|         print(
 | |
|             f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
 | |
|         )
 | |
| 
 | |
| 
 | |
| def estimate_zero2_model_states_mem_needs(total_params,
 | |
|                                           num_gpus_per_node=1,
 | |
|                                           num_nodes=1,
 | |
|                                           cpu_offload=True,
 | |
|                                           additional_buffer_factor=1.5):
 | |
|     total_gpus = num_nodes * num_gpus_per_node
 | |
| 
 | |
|     if cpu_offload:
 | |
|         gpu_mem = 2 * total_params
 | |
|         cpu_mem = total_params * \
 | |
|             max(4 * total_gpus, 16) * additional_buffer_factor
 | |
|     else:
 | |
|         gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
 | |
|         cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
 | |
| 
 | |
|     return int(cpu_mem), int(gpu_mem)
 | |
| 
 | |
| 
 | |
| def model_to_params(model):
 | |
|     # shared params calculated only once
 | |
|     total_params = sum(
 | |
|         dict((p.data_ptr(),
 | |
|               p.numel()) for p in model.parameters()).values())
 | |
|     return total_params
 | |
| 
 | |
| 
 | |
| def estimate_zero2_model_states_mem_needs_all_live(model,
 | |
|                                                    num_gpus_per_node=1,
 | |
|                                                    num_nodes=1,
 | |
|                                                    additional_buffer_factor=1.5):
 | |
|     """
 | |
|     Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
 | |
|     for a given ``model`` and hardware setup.
 | |
| 
 | |
|     If you have an actual model object, use this function and everything will be derived
 | |
|     automatically.
 | |
| 
 | |
|     If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
 | |
|     the ``total_params`` explicitly.
 | |
| 
 | |
|     Args:
 | |
|         - ``model``: ``nn.Module`` object
 | |
|         - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
 | |
|         - ``num_nodes``: how many nodes (defaults to 1),
 | |
|         - ``additional_buffer_factor``: estimation factor (defaults to 1.5):
 | |
| 
 | |
|     """
 | |
| 
 | |
|     total_params = model_to_params(model)
 | |
| 
 | |
|     estimate_zero2_model_states_mem_needs_all_cold(
 | |
|         total_params=total_params,
 | |
|         num_gpus_per_node=num_gpus_per_node,
 | |
|         num_nodes=num_nodes,
 | |
|         additional_buffer_factor=additional_buffer_factor)
 | |
| 
 | |
| 
 | |
| def estimate_zero2_model_states_mem_needs_all_cold(total_params,
 | |
|                                                    num_gpus_per_node=1,
 | |
|                                                    num_nodes=1,
 | |
|                                                    additional_buffer_factor=1.5):
 | |
|     """
 | |
|     Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
 | |
|     for a given ``model`` and hardware setup.
 | |
| 
 | |
|     If it's a hypothetical model, use this function where you have to pass
 | |
|     the ``total_params`` and ``largest_layer_params`` explicitly.
 | |
| 
 | |
|     If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
 | |
|     will be derived automatically.
 | |
| 
 | |
|     Args:
 | |
|         - ``total_params``: total  model params
 | |
|         - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
 | |
|         - ``num_nodes``: how many nodes (defaults to 1),
 | |
|         - ``additional_buffer_factor``: estimation factor (defaults to 1.5):
 | |
| 
 | |
|     """
 | |
| 
 | |
|     def format_options(cpu_offload):
 | |
|         enabled = []
 | |
|         enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
 | |
|         return ", ".join(enabled)
 | |
| 
 | |
|     nodes_str = "nodes" if num_nodes > 1 else "node"
 | |
|     gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
 | |
|     print(
 | |
|         "Estimated memory needed for params, optim states and gradients for a:\n"
 | |
|         f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
 | |
|         f"SW: Model with {int(total_params / 1e6)}M total params.")
 | |
|     print("  per CPU  |  per GPU |   Options")
 | |
|     for cpu_offload in [True, False]:
 | |
|         cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(
 | |
|             total_params=total_params,
 | |
|             num_gpus_per_node=num_gpus_per_node,
 | |
|             num_nodes=num_nodes,
 | |
|             cpu_offload=cpu_offload,
 | |
|             additional_buffer_factor=additional_buffer_factor
 | |
|         )
 | |
| 
 | |
|         options_str = format_options(cpu_offload=cpu_offload)
 | |
|         print(
 | |
|             f" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}")
 |