diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index ab554d21d..474b78aa2 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -369,6 +369,11 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        if placement_policy == "auto" and enable_async_reduce:
+            logging.warning(
+                "enable_async_reduce requires pin_memory to achieve best performance, so pin_memory is implicitly set to True."
+            )
+            pin_memory = True
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index ed5b96519..18fbf8fc3 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -316,12 +316,13 @@ class Chunk:
         if self.shard_device.type == "cpu":
             self.cuda_shard = None
 
-    def shard_move(self, device: torch.device, force_copy: bool = False):
+    def shard_move(self, device: torch.device, force_copy: bool = False, non_blocking: bool = False):
         """Move the shard tensor in the chunk.
 
         Args:
            device: the device to which the shard will move
            force_copy: if True, copy function is called mandatorily
+           non_blocking: if True, the operation is non-blocking and the caller is responsible for synchronization
        """
         # sanity check
         assert not self.is_gathered
@@ -329,7 +330,7 @@ class Chunk:
         # just use another way for the movement
         if not self.optim_sync_flag:
             assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
-            self.__paired_shard_move()
+            self.__paired_shard_move(non_blocking=non_blocking)
             self.optim_sync_flag = True
             return
 
@@ -339,7 +340,7 @@ class Chunk:
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -349,11 +350,11 @@ class Chunk:
 
             if self.pin_memory:
                 if force_copy or not self.cpu_vis_flag:
-                    self.cpu_shard.copy_(self.cuda_shard)
+                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=non_blocking)
                 # if cpu_shard has been visited
                 # copy operation is not need
             else:
-                self.cpu_shard = self.cuda_shard.cpu()
+                self.cpu_shard = self.cuda_shard.to("cpu", non_blocking=non_blocking)
             self.cpu_vis_flag = True
             self.cuda_shard = None
         else:
@@ -542,7 +543,7 @@ class Chunk:
             free_storage(self.cuda_global_chunk)
             self.is_gathered = False
 
-    def __paired_shard_move(self):
+    def __paired_shard_move(self, non_blocking: bool = False):
         assert self.paired_chunk is not None, "chunks should be paired before training"
         optim_chunk = self.paired_chunk
         assert self.chunk_size == optim_chunk.chunk_size
@@ -550,7 +551,7 @@ class Chunk:
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
         # avoid to transform FP32 in CPU
         self.cuda_shard = temp.to(self.dtype)
 
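The `non_blocking` plumbing above only pays off when the CPU shard lives in pinned memory, which is why the plugin change forces `pin_memory = True`. Below is a minimal sketch of that behavior in plain PyTorch; tensor names and sizes are illustrative and not taken from the ColossalAI code.

```python
import torch

if torch.cuda.is_available():
    pageable = torch.randn(1024, 1024)             # ordinary pageable host memory
    pinned = torch.randn(1024, 1024).pin_memory()  # page-locked host memory

    # With pageable memory, non_blocking=True may silently fall back to a
    # synchronous copy, so nothing overlaps with GPU work.
    a = pageable.to("cuda", non_blocking=True)

    # With pinned memory, the copy is queued on the current stream and the call
    # returns immediately; the caller must synchronize before using the result,
    # mirroring the "caller is responsible for synchronization" note in shard_move.
    b = pinned.to("cuda", non_blocking=True)
    torch.cuda.current_stream().synchronize()
```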
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 36e7ee57b..45066ca89 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -117,7 +117,7 @@ class ChunkManager:
             return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_accelerator().get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device(), non_blocking=async_access)
         maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
         return maybe_work
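The gemini_ddp.py changes below hang the gradient reduction off a dedicated CUDA stream: the grad hook makes the side stream wait on the default stream, records the reduce/scale/move work inside `torch.cuda.stream(...)`, and `_post_backward` synchronizes that stream once per step. Here is a minimal sketch of this producer/consumer stream pattern in plain PyTorch; the function name and the `mul_` stand-in are placeholders, not ColossalAI APIs.

```python
import torch

def reduce_on_side_stream(grad: torch.Tensor, side_stream: torch.cuda.Stream) -> None:
    # Order the side stream after everything already queued on the current
    # stream, so it sees the fully produced gradient.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        grad.mul_(0.5)  # stand-in for reduce_chunk / div_ / move_chunk

if torch.cuda.is_available():
    grad = torch.ones(4, device="cuda")
    side_stream = torch.cuda.Stream()
    reduce_on_side_stream(grad, side_stream)
    side_stream.synchronize()  # what _post_backward() does before reading results
```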
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 050643dfa..6f6064000 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -147,6 +147,12 @@ class GeminiDDP(ModelWrapper):
         self.extra_dp_group = extra_dp_group
 
         self.master_weights = master_weights
+        self.enable_async_reduce = enable_async_reduce
+
+        if enable_async_reduce:
+            self.async_reduce_stream = torch.cuda.Stream()
+        else:
+            self.async_reduce_stream = None
 
         self._logger = get_dist_logger()
 
@@ -176,6 +182,7 @@ class GeminiDDP(ModelWrapper):
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
+        # register grad hook
         for p in module.parameters():
             if is_ddp_ignored(p):
@@ -191,7 +198,7 @@ class GeminiDDP(ModelWrapper):
                     master_weights=self.master_weights,
                     enable_gradient_accumulation=self.enable_gradient_accumulation,
                     p=p,
-                    async_reduce=enable_async_reduce,
+                    async_reduce_stream=self.async_reduce_stream,
                 )
             )
 
@@ -339,10 +346,8 @@ class GeminiDDP(ModelWrapper):
             setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
-        for param in self.param2name:
-            if hasattr(param, "_release_grad_chunk_cb"):
-                param._release_grad_chunk_cb()
-                delattr(param, "_release_grad_chunk_cb")
+        if self.enable_async_reduce:
+            self.async_reduce_stream.synchronize()
 
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
@@ -381,7 +386,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce: bool,
+        async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -417,56 +422,35 @@ class GeminiDDP(ModelWrapper):
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
-        if reduced:  # if not async, can release immediately, else release in when work finished
-            if async_reduce:
-                # dirty fix by installing callback
-                assert not hasattr(p, "_release_grad_chunk_cb")
-
-                def _release_grad_chunk_cb():
-                    grad_chunk.wait_async_reduce()
-                    GeminiDDP.release_grad_chunk_handle(
-                        chunk_manager,
-                        grads_device,
-                        master_weights,
-                        enable_gradient_accumulation,
-                        p,
-                        chunk,
-                        grad_chunk,
-                    )
+        if async_reduce_stream is not None:
+            async_reduce_stream.wait_stream(torch.cuda.current_stream())
 
-                p._release_grad_chunk_cb = _release_grad_chunk_cb
-            else:
-                GeminiDDP.release_grad_chunk_handle(
-                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-                )
-        return empty_grad
-
-    @staticmethod
-    def release_grad_chunk_handle(
-        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-    ):
-        if not chunk_manager.reuse_fp16_chunk:
-            if chunk.keep_gathered:
-                chunk_manager.fake_release_chunk(chunk)
-            else:
-                chunk_manager.release_chunk(chunk)
-        if grad_chunk.is_gathered:
-            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-        else:
-            grad_chunk.cuda_shard.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-        # check overflow elements
-        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-        if chunk.l2_norm_flag:
-            grad_chunk.set_l2_norm()
-        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-        if not (master_weights) or (enable_gradient_accumulation):
-            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        with torch.cuda.stream(async_reduce_stream):
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+            if reduced:
+                grad_chunk.wait_async_reduce()
+                if not chunk_manager.reuse_fp16_chunk:
+                    if chunk.keep_gathered:
+                        chunk_manager.fake_release_chunk(chunk)
+                    else:
+                        chunk_manager.release_chunk(chunk)
+                if grad_chunk.is_gathered:
+                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+                else:
+                    grad_chunk.cuda_shard.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+                # check overflow elements
+                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+                if chunk.l2_norm_flag:
+                    grad_chunk.set_l2_norm()
+                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                if not (master_weights) or (enable_gradient_accumulation):
+                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
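For context, a minimal end-to-end sketch of enabling the new path from the plugin. Only the arguments visible in this diff (`precision`, `placement_policy`, `pin_memory`, `enable_async_reduce`) are taken from the source; the launcher, Booster, and HybridAdam wiring follows the usual colossalai pattern and is illustrative rather than specified here.

```python
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes a torchrun launch

plugin = GeminiPlugin(
    precision="bf16",
    placement_policy="auto",
    enable_async_reduce=True,  # grad reduction runs on the dedicated CUDA stream
    # pin_memory is forced to True by the gemini_plugin.py change above
)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```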