diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 964cd302a..eb8db6212 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -361,6 +361,7 @@ class GeminiPlugin(DPPluginBase):
         enable_sequence_parallelism: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
+        enable_async_reduce: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            enable_async_reduce=enable_async_reduce,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index cad2622f2..8f048f0b7 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -164,6 +164,8 @@ class Chunk:
         self.l2_norm = None

         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None

     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -244,7 +246,7 @@ class Chunk:
             assert self.cuda_shard is not None  # only check on CUDA
             valid_tensor = self.cuda_shard[: self.valid_end]

-        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+        return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()

     def set_l2_norm(self) -> None:
         """Record l2 norm of this chunks on CUDA."""
@@ -374,37 +376,49 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()

-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
-            if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            if self.extra_dp_group is not None:  # cannot guarantee the order of multiple all-reduce ops
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )
+
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)

             free_storage(self.cuda_global_chunk)
             self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)

+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 333a3f224..6ec595914 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -41,7 +41,7 @@ class ChunkManager:
         self.reuse_fp16_chunk = reuse_fp16_chunk
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
-        self.overflow_counter = 0
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())

     def register_tensor(
         self,
@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
@@ -272,7 +272,7 @@ class ChunkManager:
         return grad_chunk

     def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
-        """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
+        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""

         assert chunk.grad_chunk is not None
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index c1029097a..23f6ee683 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )

@@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
             setattr(param, "_gemini_reduced", False)

     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
@@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper):
                 grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk)
-        if reduced:
-            if not chunk_manager.reuse_fp16_chunk:
-                if chunk.keep_gathered:
-                    chunk_manager.fake_release_chunk(chunk)
-                else:
-                    chunk_manager.release_chunk(chunk)
-            if grad_chunk.is_gathered:
-                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+        if reduced:  # if not async, release immediately; otherwise release once the reduce work has finished
+            if async_reduce:
+                # dirty fix by installing a callback
+                assert not hasattr(p, "_release_grad_chunk_cb")
+
+                def _release_grad_chunk_cb():
+                    grad_chunk.wait_async_reduce()
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager,
+                        grads_device,
+                        master_weights,
+                        enable_gradient_accumulation,
+                        p,
+                        chunk,
+                        grad_chunk,
+                    )
+
+                p._release_grad_chunk_cb = _release_grad_chunk_cb
             else:
-                grad_chunk.cuda_shard.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-            # check overflow elements
-            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-            if chunk.l2_norm_flag:
-                grad_chunk.set_l2_norm()
-            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-            if not (master_weights) or (enable_gradient_accumulation):
-                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                GeminiDDP.release_grad_chunk_handle(
+                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                )
         return empty_grad

+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
+        if not chunk_manager.reuse_fp16_chunk:
+            if chunk.keep_gathered:
+                chunk_manager.fake_release_chunk(chunk)
+            else:
+                chunk_manager.release_chunk(chunk)
+        if grad_chunk.is_gathered:
+            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        else:
+            grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+        # check overflow elements
+        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+        if chunk.l2_norm_flag:
+            grad_chunk.set_l2_norm()
+        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+        if not (master_weights) or (enable_gradient_accumulation):
+            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 18918eabc..1d755c417 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module

     def check_local_overflow(self) -> bool:
-        return self.module.chunk_manager.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter.item() > 0

     def pre_zero_grad(self) -> None:
-        self.module.chunk_manager.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter.zero_()


 class GeminiOptimizer(OptimizerWrapper):
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 5cc602181..6f91ff7b7 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -76,6 +76,8 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce", default=False)
+
     args = parser.parse_args()

     colossalai.launch_from_torch()
@@ -110,6 +112,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index 257311328..51b20c400 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -34,7 +34,8 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory):
+@parameterize("async_op", [True, False])
+def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce()
+    my_chunk.reduce(async_op)
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

+    if async_op:
+        my_chunk.wait_async_reduce()
+
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 570a0aa42..4279793d7 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    enable_async_reduce=True,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
     model = GeminiDDP(
-        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+        model,
+        config_dict,
+        init_device,
+        pin_memory=True,
+        **placement_config,
+        master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index fd0e9fd7c..6e6c27e3f 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    enable_async_reduce: bool,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,10 +87,13 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
+    gemini_optim = GeminiOptimizer(
+        optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce
+    )

     rank = dist.get_rank()
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 0a9bac092..7a1609ca5 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
+@parameterize("enable_async_reduce", [False, True])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index a9366e7bc..c610259b2 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -73,7 +73,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
+@parameterize("enable_async_reduce", [False, True])
+def exam_model_step(
+    placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True
+):
     set_seed(42)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -96,7 +99,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
     model = GeminiDDP(
-        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+        model,
+        config_dict,
+        **placement_config,
+        mixed_precision=mixed_precision,
+        master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
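Note: the following is a minimal usage sketch, not part of the patch, showing how the new enable_async_reduce flag is toggled through GeminiPlugin under the usual Booster workflow (as in examples/language/llama/benchmark.py); the Linear model, precision, and learning rate are illustrative placeholders.

    # sketch: enabling/disabling asynchronous gradient reduction in Gemini
    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    # requires a torchrun-style distributed launch
    colossalai.launch_from_torch()

    # enable_async_reduce=True (the new default) makes GeminiDDP launch the gradient
    # all-reduce/reduce-scatter with async_op=True; the pending work is waited on in
    # _post_backward before gradients are released and moved.
    plugin = GeminiPlugin(precision="bf16", enable_async_reduce=True)
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(1024, 1024)  # placeholder model
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

In the benchmark script, the same behaviour maps to the new --disable-async-reduce switch: async reduce stays on unless the flag is passed, since the plugin receives enable_async_reduce=not args.disable_async_reduce.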