From 77ec7733884b01dfec4e1958b2e9b1ba5d5036ec Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 7 May 2024 12:01:38 +0800 Subject: [PATCH] [zero]remove registered gradients hooks (#5687) * remove registered hooks fix fix fix zero fix fix fix fix fix zero fix zero fix fix fix * fix fix fix --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +- .../zero/low_level/bookkeeping/base_store.py | 1 + .../low_level/bookkeeping/bucket_store.py | 31 +- .../low_level/bookkeeping/gradient_store.py | 9 +- colossalai/zero/low_level/low_level_optim.py | 369 ++++++++++-------- .../test_plugin/test_low_level_zero_plugin.py | 1 - .../test_model/test_shard_llama.py | 4 +- 7 files changed, 256 insertions(+), 167 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 97057481e..14d9935f3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -735,7 +735,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Get all working gradients and gradients to be synchronized. all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: + if self._grad_store.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: @@ -759,7 +759,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward method to compute gradients. super().backward(loss, retain_graph) - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: @@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward_by_grad method to compute gradients. super().backward_by_grad(tensor, grad) - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: @@ -1272,7 +1272,7 @@ class HybridParallelPlugin(PipelinePluginBase): # run with gradients accumulation if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False ): return outputs diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 107d62dcb..7f2f9664b 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -6,6 +6,7 @@ class BaseStore: def __init__(self, torch_pg: ProcessGroup): self._world_size = dist.get_world_size(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg) + self.torch_pg = torch_pg @property def world_size(self): diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2ebc704f7..1496603fa 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,16 +1,43 @@ -from typing import Dict +from typing import Dict, Optional import torch +import torch.distributed as dist from torch import Tensor from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator + from .base_store import BaseStore class BucketStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): + def __init__( + self, + torch_pg: ProcessGroup, + reduce_bucket_size: int, + overlap_communication: bool, + communication_dtype: Optional[torch.dtype] = None, + moe_extra_dp_process_group: ProcessGroup = None, + ): super().__init__(torch_pg) + self.reduce_bucket_size = reduce_bucket_size + # communication params + self._overlap_communication = overlap_communication + self._communication_dtype = communication_dtype + if self._overlap_communication: + self.comm_stream = get_accelerator().Stream() + self.zero_local_rank = dist.get_rank(group=self.torch_pg) + self.zero_world_size = dist.get_world_size(group=self.torch_pg) + # extra dp + # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. + # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. + # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. + # And moe working and master param are split by extra dp pg. + self.moe_extra_dp_pg = moe_extra_dp_process_group + if self.moe_extra_dp_pg is not None: + self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) + self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) self.reset_all() def reset_all(self) -> None: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 6d4fcbb86..fc28b7795 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -6,7 +6,7 @@ from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False): + def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True): super().__init__(*args) """ self._grads_of_params mapping the parameter and its gradient slices @@ -18,9 +18,12 @@ class GradientStore(BaseStore): } """ self._grads_of_params = dict() - # for zero2, it's `param_id: [grad_local_rank]` + # stage 2 + self._partition_grads = partition_grad + # grad accumulation + self.require_grad_sync = require_grad_sync self._working_index = 0 if partition_grad else self._local_rank - + # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 345dfde73..1b856cafd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -90,38 +90,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._logger = get_dist_logger() self._verbose = verbose - # stage 2 - self._partition_grads = partition_grad - self._cpu_offload = cpu_offload - # grad accumulation - self.require_grad_sync = True - - # if process_group is none, will use the default one - self.dp_pg = dp_process_group - self._local_rank = dist.get_rank(group=self.dp_pg) - self._world_size = dist.get_world_size(group=self.dp_pg) - - # extra dp - # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. - # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. - # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. - # And moe working and master param are split by extra dp pg. - self.moe_extra_dp_pg = moe_extra_dp_process_group - if self.moe_extra_dp_pg is not None: - self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) - self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) - # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() - # communication params - self._overlap_communication = overlap_communication - self._reduce_bucket_size = reduce_bucket_size - self._communication_dtype = communication_dtype - # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -140,9 +114,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self.dp_pg) - self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) - self._bucket_store = BucketStore(self.dp_pg) + self._param_store = ParameterStore(dp_process_group) + self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True) + self._bucket_store = BucketStore( + dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group + ) # moe param should not be stored in working_groups # because they have different parallel strategy @@ -157,7 +133,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): group_params = list() for param in param_group["params"]: if param.requires_grad: - if self.moe_extra_dp_pg is None: + if self._bucket_store.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): self.working_moe_params.append(param) @@ -194,15 +170,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) - # initialize communication stream for - # communication-computation overlapping - if self._overlap_communication: - self._comm_stream = get_accelerator().Stream() - # reduction hook is only used if overlapping communication # or stage 2 is used # if it is stage 1 without overlapping, no hook will be attached - if self._overlap_communication or self._partition_grads: + if self._bucket_store._overlap_communication or self._grad_store._partition_grads: self._attach_reduction_hook() # initialize mixed precision mixin @@ -222,6 +193,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() + def __del__(self): + self.remove_hooks() + @property def dtype(self): return self._dtype @@ -246,7 +220,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() for param in param_list: - padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size + padding_size = ( + self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size + ) % self._bucket_store.zero_world_size self._param_store.record_param_padding_size(param, padding_size) with torch.no_grad(): @@ -258,12 +234,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: padding_param = param.data.view(-1) - if self.moe_extra_dp_pg is not None and is_moe_tensor(param): - splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size) - splited_params = splited_params[self.moe_extra_dp_pg_rank] + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param): + splited_params = padding_param.split( + padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size + ) + splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank] else: - splited_params = padding_param.split(padding_param.numel() // self._world_size) - splited_params = splited_params[self._local_rank] + splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size) + splited_params = splited_params[self._bucket_store.zero_local_rank] # use fp32 when master_weights is True if self._master_weights is True: @@ -280,10 +258,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Backward Reduction Hook # ########################### - def _grad_handler(self, group_id, param): + @staticmethod + def grad_handler( + param: nn.Parameter, + group_id: int, + bucket_store: BucketStore, + param_store: ParameterStore, + grad_store: GradientStore, + ): # if run with no_sync context, would not sync grad when backward - if self.require_grad_sync: - self._add_to_bucket(param, group_id) + if grad_store.require_grad_sync: + LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store) def _attach_reduction_hook(self): # we iterate over the working params @@ -292,29 +277,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) + param._grad_handle = param.register_post_accumulate_grad_hook( + partial( + LowLevelZeroOptimizer.grad_handler, + group_id=group_id, + bucket_store=self._bucket_store, + param_store=self._param_store, + grad_store=self._grad_store, + ) + ) ####################### # Reduction Functions # ####################### - - def _run_reduction(self): - if self._bucket_store.num_elements_in_bucket() > 0: - self._bucket_store.build_grad_in_bucket() - - if self.moe_extra_dp_pg is None: - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + @staticmethod + def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): + if bucket_store.num_elements_in_bucket() > 0: + bucket_store.build_grad_in_bucket() + if bucket_store.moe_extra_dp_pg is None: + flat_grads = bucket_store.get_flatten_grad() + flat_grads /= bucket_store.zero_world_size else: # record moe and non moe param moe_list = [] - for param in self._bucket_store._param_list: + for param in bucket_store._param_list: moe_list.append(is_moe_tensor(param)) # divide them into different groups moe_grad_list = [] non_moe_grad_list = [] - for grad_list in self._bucket_store._grad_in_bucket.values(): + for grad_list in bucket_store._grad_in_bucket.values(): non_moe_cur_grad = [] moe_cur_grad = [] for i in range(len(grad_list)): @@ -332,7 +324,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for grad_list in non_moe_grad_list: non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) - non_moe_flat_grads /= self._world_size + non_moe_flat_grads /= bucket_store.zero_world_size if len(moe_grad_list) > 0: moe_flat_grads = [] @@ -341,12 +333,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) # ready to add other tensors to bucket - self._bucket_store.reset_num_elements_in_bucket() + bucket_store.reset_num_elements_in_bucket() - if self._overlap_communication: - stream = self._comm_stream + if bucket_store._overlap_communication: + stream = bucket_store.comm_stream # in case of the memory being reused in the default stream - if self.moe_extra_dp_pg is None: + if bucket_store.moe_extra_dp_pg is None: flat_grads.record_stream(stream) else: if len(non_moe_grad_list) > 0: @@ -359,53 +351,63 @@ class LowLevelZeroOptimizer(OptimizerWrapper): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - group_id = self._bucket_store.current_group_id + group_id = bucket_store.current_group_id - if self.moe_extra_dp_pg is None: + if bucket_store.moe_extra_dp_pg is None: grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) + if bucket_store._communication_dtype is not None: + flat_grads = flat_grads.to(bucket_store._communication_dtype) - if not self._partition_grads: - if self.moe_extra_dp_pg is None: - dist.all_reduce(flat_grads, group=self.dp_pg) + if not grad_store._partition_grads: + if bucket_store.moe_extra_dp_pg is None: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() - self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size) + grad_in_bucket = bucket_store.get_grad() + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id + ) # sync extra zero group else: # sync non moe param in global dp group if len(non_moe_grad_list) > 0: - dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) + dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg) flat_grads_per_rank = non_moe_flat_grads.split( - non_moe_flat_grads.numel() // self._world_size + non_moe_flat_grads.numel() // bucket_store.zero_world_size + ) + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id ) - self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group if len(moe_grad_list) > 0: - dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) - flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) - self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id) + dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg) + flat_grads_per_rank = moe_flat_grads.split( + moe_flat_grads.numel() // bucket_store.zero_world_size + ) + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id + ) else: - if self.moe_extra_dp_pg is None: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + if bucket_store.moe_extra_dp_pg is None: + flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size)) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) if recieved_grad.dtype != grad_dtype: recieved_grad = recieved_grad.to(grad_dtype) - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] + LowLevelZeroOptimizer.update_partitoned_grad( + bucket_store, grad_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1 + ) else: # categorize moe and non moe param - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] moe_grad_in_bucket_current_rank = [] non_moe_grad_in_bucket_current_rank = [] for idx, grad in enumerate(grad_in_bucket_current_rank): @@ -416,11 +418,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(non_moe_grad_list) > 0: flat_grads_list = list( - non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size) + non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size) ) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - self._update_partitoned_grad( + dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) + LowLevelZeroOptimizer.update_partitoned_grad( + bucket_store, + grad_store, non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, @@ -429,35 +433,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(moe_grad_list) > 0: flat_grads_list = list( - moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) + moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size) ) recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter( recieved_grad, flat_grads_list, - group=self.moe_extra_dp_pg, + group=bucket_store.moe_extra_dp_pg, ) - param_slice = self._world_size // self.moe_extra_dp_pg_size + param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) for split_recieved_grad in recieved_grad: split_recieved_grad = _unflatten_dense_tensors( split_recieved_grad, moe_grad_in_bucket_current_rank ) for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(real_grad, param_slice, group_id, param_id) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad( + grad_store, real_grad, param_slice, group_id, param_id + ) - self._bucket_store.reset() + bucket_store.reset() - def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: + @staticmethod + def update_unpartitoned_grad( + bucket_store: BucketStore, + grad_store: GradientStore, + origin_grad_list: List, + flat_grad_list: List, + group_id: int, + ) -> None: for rank, grad_list in enumerate(origin_grad_list): sync_tensor(flat_grad_list[rank], grad_list) for grad in grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, self._world_size, group_id, param_id, rank) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank) - def _update_partitoned_grad( - self, + @staticmethod + def update_partitoned_grad( + bucket_store: BucketStore, + grad_store: GradientStore, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, @@ -465,23 +480,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ) -> None: sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, partition_num, group_id, param_id) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id) - def _add_grad( - self, + @staticmethod + def add_grad( + grad_store: GradientStore, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0, ) -> None: - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) - def _add_to_bucket(self, param, group_id): + @staticmethod + def add_to_bucket( + param: nn.Parameter, + group_id: int, + bucket_store: BucketStore, + param_store: ParameterStore, + grad_store: GradientStore, + ): param_size = param.numel() # check if the bucket is full @@ -489,13 +512,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # or got a grad of param from another group # after reduction, the bucket will be empty if ( - self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size - or group_id != self._bucket_store.current_group_id + bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size + or group_id != bucket_store.current_group_id ): - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store) - padding_size = self._param_store.get_param_padding_size(param) - self._bucket_store.add_param_grad(group_id, param, padding_size) + padding_size = param_store.get_param_padding_size(param) + bucket_store.add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods @@ -503,7 +526,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def backward(self, loss, retain_graph=False): assert not ( - self._partition_grads and not self.require_grad_sync + self._grad_store._partition_grads and not self._grad_store.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: @@ -511,31 +534,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper): loss.backward(retain_graph=retain_graph) - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return - self._reduce_grad(self._partition_grads) + self._reduce_grad(self._grad_store._partition_grads) # clear reduced grads - if self._overlap_communication: + if self._bucket_store._overlap_communication: get_accelerator().synchronize() self.zero_grad() def backward_by_grad(self, tensor, grad): assert not ( - self._partition_grads and not self.require_grad_sync + self._grad_store._partition_grads and not self._grad_store.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) torch.autograd.backward(tensor, grad) - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return - self._reduce_grad(self._partition_grads) + self._reduce_grad(self._grad_store._partition_grads) # clear reduced grads - if self._overlap_communication: + if self._bucket_store._overlap_communication: get_accelerator().synchronize() self.zero_grad() @@ -566,7 +589,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def step(self, closure=None): assert closure is None, "closure is not supported by step()" - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): @@ -585,7 +608,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # and should not be updated real_working_params = dict() real_master_params = dict() - grad_index = 0 if self._partition_grads else self._local_rank + grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank for group_id in range(self.num_param_groups): master_params = self._master_param_groups_of_current_rank[group_id] real_working_params[group_id] = [] @@ -598,14 +621,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper): grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: # moe hybrid zero - if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): real_working_params[group_id].append(working_param) - if self._partition_grads: + if self._grad_store._partition_grads: grad = grads else: - param_slice = self._world_size // self.moe_extra_dp_pg_size + param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size grad = grads[ - self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice + self._bucket_store.moe_extra_dp_pg_rank + * param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1) + * param_slice ] grad = flatten(grad) else: @@ -674,25 +699,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) - for _ in range(self.moe_extra_dp_pg_size) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] dist.all_gather( all_splited_param, splited_param.to(device).to(self._dtype), - group=self.moe_extra_dp_pg, + group=self._bucket_store.moe_extra_dp_pg, ) else: all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) - for _ in range(self._world_size) + for _ in range(self._bucket_store.zero_world_size) ] dist.all_gather( all_splited_param, splited_param.to(device).to(self._dtype), - group=self.dp_pg, + group=self._bucket_store.torch_pg, ) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -720,7 +745,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device=get_accelerator().get_current_device(), dtype=torch.float, ) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg) total_norm = total_norm_cuda.item() else: @@ -738,7 +763,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, - group=self.dp_pg, + group=self._bucket_store.torch_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -773,27 +798,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad and param.grad is not None: - self._add_to_bucket(param, group_id) + LowLevelZeroOptimizer.add_to_bucket( + param, + group_id, + self._bucket_store, + self._param_store, + self._grad_store, + ) - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) def _reduce_grad(self, partition_grad): # if not overlapping communication (no reduction hook is attached) when zero1 # we need to manually reduce these gradients - if not partition_grad and not self._overlap_communication: + if not partition_grad and not self._bucket_store._overlap_communication: self._sync_grad() else: - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) # this context comes from pytorch DDP @contextmanager def no_sync(self): - old_require_grad_sync = self.require_grad_sync - self.require_grad_sync = False + old_require_grad_sync = self._grad_store.require_grad_sync + self._grad_store.require_grad_sync = False try: yield finally: - self.require_grad_sync = old_require_grad_sync + self._grad_store.require_grad_sync = old_require_grad_sync ############## # State Dict # @@ -833,16 +864,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] - dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) else: gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.zero_world_size) ] - dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg) param_state = ( torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -862,17 +895,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_idx, state in zero_state_dict["state"].items(): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size + padding_size = ( + self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size + ) % self._bucket_store.zero_world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): - v_list = v.split(v.numel() // self.moe_extra_dp_pg_size) - zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone() + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): + v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size) + zero_state_dict["state"][param_idx][k] = ( + v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone() + ) else: - v_list = v.split(v.numel() // self._world_size) - zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() + v_list = v.split(v.numel() // self._bucket_store.zero_world_size) + zero_state_dict["state"][param_idx][k] = ( + v_list[self._bucket_store.zero_local_rank].detach().clone() + ) self.optim.load_state_dict(zero_state_dict) @@ -904,16 +943,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] - dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) else: state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.zero_world_size) ] - dist.all_gather(state_tensor, v.to(device), group=self.dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg) state_tensor = ( torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -944,14 +985,30 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - if self.moe_extra_dp_pg is not None and is_moe_tensor(p): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p): master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) else: - master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + master_param.copy_( + working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank] + ) if hasattr(self, "master_moe_params"): for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): master_moe_param.copy_(working_moe_param) + def remove_hooks(self) -> None: + """remove the registered hooks + + Args: + plugin (LowLevelZeroPlugin): the plugin to bound this method. + """ + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad: + assert hasattr(param, "_grad_handle") + param._grad_handle.remove() + delattr(param, "_grad_handle") + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self._param_store.working_to_master_param diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 4908b2d4f..8c59f430c 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -80,7 +80,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): skipped_models.append(name) continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - get_accelerator().empty_cache() if err is None: diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 104ede981..c38570f85 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -64,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) - grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank + grad_index = ( + 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank + ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)