From 3788fefc7a24ae4da18a29ca69e6d3b1473d306c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 16 Apr 2024 17:49:21 +0800
Subject: [PATCH] [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements
---
 .../low_level/bookkeeping/bucket_store.py    |  2 +
 colossalai/zero/low_level/low_level_optim.py | 67 +++++++++++++++----
 requirements/requirements.txt                |  2 +-
 3 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index f395fc60e..2ebc704f7 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -11,7 +11,9 @@ from .base_store import BaseStore
 class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
+        self.reset_all()
 
+    def reset_all(self) -> None:
         # init
         self.current_group_id = 0
         self._num_elements_in_bucket = 0
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index bbbaf13b5..cbcf72697 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -40,7 +40,13 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store
@@ -273,11 +279,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Backward Reduction Hook #
     ###########################
 
-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
         # if run with no_sync context, would not sync grad when backward
         if self.require_grad_sync:
             self._add_to_bucket(param, group_id)
-        return grad
 
     def _attach_reduction_hook(self):
         # we iterate over the working params
@@ -286,7 +291,7 @@
             param_group = self._working_param_groups[group_id]
             for param in param_group:
                 if param.requires_grad:
-                    param.register_hook(partial(self._grad_handler, param, group_id))
+                    param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))
     #######################
     # Reduction Functions #
     #######################
@@ -415,7 +420,10 @@
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
                 dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
                 self._update_partitoned_grad(
-                    non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                    non_moe_grad_in_bucket_current_rank,
+                    recieved_grad,
+                    group_id,
+                    1,
                 )
 
             if len(moe_grad_list) > 0:
@@ -423,7 +431,11 @@
                     moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                 )
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
-                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+                dist.reduce_scatter(
+                    recieved_grad,
+                    flat_grads_list,
+                    group=self.moe_extra_dp_pg,
+                )
                 param_slice = self._world_size // self.moe_extra_dp_pg_size
                 recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
                 for split_recieved_grad in recieved_grad:
@@ -444,14 +456,25 @@
                 self._add_grad(grad, self._world_size, group_id, param_id, rank)
 
     def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
     ) -> None:
         sync_tensor(flat_grad, origin_grad_list)
         for grad in origin_grad_list:
             param_id = self._bucket_store.get_param_id_of_grad(grad)
             self._add_grad(grad, partition_num, group_id, param_id)
 
-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
         else:
@@ -534,6 +557,7 @@
                 if param.grad is not None:
                     param.grad.detach()
                     param.grad.zero_()
+        self._bucket_store.reset_all()
     ####################
     # Update Parameter #
     ####################
@@ -655,14 +679,20 @@
                     for _ in range(self.moe_extra_dp_pg_size)
                 ]
                 dist.all_gather(
-                    all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                    all_splited_param,
+                    splited_param.to(device).to(self._dtype),
+                    group=self.moe_extra_dp_pg,
                 )
             else:
                 all_splited_param = [
                     torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                     for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                dist.all_gather(
+                    all_splited_param,
+                    splited_param.to(device).to(self._dtype),
+                    group=self.dp_pg,
+                )
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
 
         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -685,7 +715,9 @@
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@
 
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
            )
             torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+                total_norm_exponentiated_cuda,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.dp_pg,
             )
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
@@ -920,5 +956,8 @@
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
         return self._param_store.master_to_working_param
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 095617d76..fd97f5c5a 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=1.12
+torch>=2.1.0
 safetensors
 einops
 pydantic
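The switch from Tensor.register_hook to Tensor.register_post_accumulate_grad_hook (which is why torch is bumped to >=2.1.0) is the core of the change: the new hook fires only after a gradient has been accumulated into param.grad, so the bucket handler works with the accumulated gradient rather than the raw gradient of each individual backward call, which is what makes several partial backward passes before one optimizer step workable. Below is a minimal standalone sketch of that hook behaviour in plain PyTorch >= 2.1; it is not taken from the patch, and the toy parameter and print callbacks are illustrative only.

    import torch

    p = torch.nn.Parameter(torch.ones(3))

    # Old-style hook: called with the incoming gradient of each backward pass,
    # before it is accumulated into p.grad (this is what the old _grad_handler received).
    p.register_hook(lambda grad: print("register_hook saw", grad.tolist()))

    # New-style hook (torch >= 2.1): called with the parameter itself after
    # accumulation, so p.grad already holds the sum over partial backward passes.
    p.register_post_accumulate_grad_hook(
        lambda param: print("post-accumulate, p.grad is", param.grad.tolist())
    )

    (1.0 * p).sum().backward()  # first partial backward: p.grad == [1, 1, 1]
    (2.0 * p).sum().backward()  # second partial backward: p.grad == [3, 3, 3]

The new BucketStore.reset_all() call added on the zero_grad() path complements this by clearing the bucket bookkeeping before the next accumulation cycle begins.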