[gemini] optimize reduce scatter d2h copy

hxwang 2024-05-28 14:23:22 +00:00
parent b96c6390f4
commit b5ae587d50
3 changed files with 41 additions and 56 deletions
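In short: the device-to-host side of the reduce-scatter path is made asynchronous by pinning the CPU buffers, tagging the shard copies with non_blocking=True, and queuing the reduction on a dedicated CUDA stream that is synchronized once after backward. This relies on two standard PyTorch/CUDA facts: a non_blocking copy to the host only overlaps with compute when the destination tensor is pinned, and work queued on a separate torch.cuda.Stream runs concurrently with the default stream until it is explicitly synchronized. A minimal, self-contained sketch of the pattern (all names below are illustrative, not taken from the ColossalAI code):

    import torch

    # Minimal sketch of the pattern this commit applies: queue the gradient
    # post-processing and the device-to-host offload on a side stream so the
    # default stream can keep launching backward kernels.
    side_stream = torch.cuda.Stream()

    grad = torch.randn(1 << 20, device="cuda")
    cpu_buf = torch.empty(grad.shape, dtype=grad.dtype, pin_memory=True)  # pinned => truly async D2H

    side_stream.wait_stream(torch.cuda.current_stream())  # grad is produced on the default stream
    with torch.cuda.stream(side_stream):
        grad.div_(8)                            # stand-in for the post-reduce rescale
        cpu_buf.copy_(grad, non_blocking=True)  # overlaps with default-stream work

    # ... the default stream can keep running kernels here ...

    side_stream.synchronize()                   # fence before the host reads cpu_buf

The three files below apply exactly this: the plugin forces pin_memory on, Chunk tags its shard copies non_blocking, and GeminiDDP owns the stream plus the final synchronize.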

View File

@@ -368,6 +368,11 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        if placement_policy == "auto" and enable_async_reduce:
+            logging.warning(
+                f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+            )
+            pin_memory = True
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
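The forced pin_memory = True is not cosmetic: non_blocking=True is only a hint, and a device-to-host copy into ordinary pageable host memory still blocks the calling thread, so without pinned buffers the async reduce path would silently lose its overlap. A small illustration of the difference (shapes and names are arbitrary):

    import torch

    src = torch.randn(4096, 4096, device="cuda")

    pageable_dst = torch.empty(4096, 4096)                 # pageable host memory
    pinned_dst = torch.empty(4096, 4096, pin_memory=True)  # page-locked host memory

    pageable_dst.copy_(src, non_blocking=True)  # effectively synchronous despite the flag
    pinned_dst.copy_(src, non_blocking=True)    # returns immediately; copy runs in the background

    torch.cuda.current_stream().synchronize()   # required before the host reads pinned_dst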

View File

@@ -339,7 +339,7 @@ class Chunk:
         if self.cuda_shard:
             return

-        self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
+        self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)

         if not self.pin_memory:
             self.cpu_shard = None
@@ -349,7 +349,7 @@
         if self.pin_memory:
             if force_copy or not self.cpu_vis_flag:
-                self.cpu_shard.copy_(self.cuda_shard)
+                self.cpu_shard.copy_(self.cuda_shard, non_blocking=True)
             # if cpu_shard has been visited
             # copy operation is not need
         else:
@@ -547,7 +547,7 @@
             # only be called when optimizer state is in CPU memory
             # the grad and param should be in the same device
             assert self.cuda_shard is None
-            temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
+            temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)
             # avoid to transform FP32 in CPU
             self.cuda_shard = temp.to(self.dtype)
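All three Chunk changes issue their copy on whatever stream is current at call time. Kernels queued later on the same stream, such as the temp.to(self.dtype) cast in the hunk above, are automatically ordered after the copy, so only host-side consumers of the data need an explicit fence. A sketch of that ordering guarantee, with illustrative names:

    import torch

    # Same-stream ordering is what makes the non_blocking copies above safe for
    # GPU consumers: the cast below is queued behind the async H2D copy and
    # cannot start before it completes.
    cpu_shard = torch.randn(1 << 20, dtype=torch.float32, pin_memory=True)

    temp = cpu_shard.to("cuda", non_blocking=True)  # async upload from pinned memory
    cuda_shard = temp.to(torch.float16)             # ordered after the copy on the same stream

    # A host-side read of a D2H result would still need an explicit fence:
    # torch.cuda.current_stream().synchronize()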

View File

@@ -145,6 +145,12 @@ class GeminiDDP(ModelWrapper):
         self.extra_dp_group = extra_dp_group
         self.master_weights = master_weights
+        self.enable_async_reduce = enable_async_reduce
+
+        if enable_async_reduce:
+            self.async_reduce_stream = torch.cuda.Stream()
+        else:
+            self.async_reduce_stream = None

         self._logger = get_dist_logger()
@@ -174,6 +180,7 @@ class GeminiDDP(ModelWrapper):
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
+
         # register grad hook
         for p in module.parameters():
             if is_ddp_ignored(p):
@@ -189,7 +196,7 @@
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
-                        async_reduce=enable_async_reduce,
+                        async_reduce_stream=self.async_reduce_stream,
                     )
                 )
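This hunk swaps the boolean async_reduce for the stream object itself, bound into each parameter's gradient hook. The surrounding registration code is not shown in this excerpt; presumably it wraps GeminiDDP.grad_handle with functools.partial and registers it via Tensor.register_hook. A generic, hedged sketch of that wiring, with a made-up hook name:

    from functools import partial
    from typing import Optional

    import torch

    # Hypothetical stand-in for GeminiDDP.grad_handle: extra state is bound via
    # functools.partial so the hook still matches register_hook's (grad) signature.
    def reduce_hook(grad: torch.Tensor, *, p: torch.nn.Parameter,
                    async_reduce_stream: Optional[torch.cuda.Stream]) -> torch.Tensor:
        if async_reduce_stream is not None:
            async_reduce_stream.wait_stream(torch.cuda.current_stream())  # grad comes from the current stream
        with torch.cuda.stream(async_reduce_stream):  # no-op context when the stream is None
            grad.div_(2)                              # placeholder for the real reduce-scatter / rescale
        return grad

    model = torch.nn.Linear(8, 8).cuda()
    side_stream = torch.cuda.Stream()
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(partial(reduce_hook, p=p, async_reduce_stream=side_stream))

Because the same self.async_reduce_stream instance is bound into every hook, all per-parameter reductions land on one queue, which is what allows the single fence in _post_backward below.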
@@ -337,10 +344,8 @@
             setattr(param, "_gemini_reduced", False)

     def _post_backward(self):
-        for param in self.param2name:
-            if hasattr(param, "_release_grad_chunk_cb"):
-                param._release_grad_chunk_cb()
-                delattr(param, "_release_grad_chunk_cb")
+        if self.enable_async_reduce:
+            self.async_reduce_stream.synchronize()

         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
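The new _post_backward replaces the per-parameter callback sweep with one stream fence. Stream.synchronize() blocks the host only until that stream's queue has drained, a narrower wait than torch.cuda.synchronize(), and that is all the subsequent optimizer step needs. A short illustration (names are arbitrary):

    import torch

    # stream.synchronize() is a per-stream fence: the host waits only for work
    # queued on that stream, not for everything on the device.
    side_stream = torch.cuda.Stream()
    x = torch.randn(1 << 20, device="cuda")

    side_stream.wait_stream(torch.cuda.current_stream())  # x comes from the default stream
    with torch.cuda.stream(side_stream):
        y = x * 2                     # stands in for the queued reduction work

    side_stream.synchronize()         # what the new _post_backward does, once
    # torch.cuda.synchronize()        # heavier alternative: waits for every stream
    print(y.sum().item())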
@@ -379,7 +384,7 @@
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce: bool,
+        async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -415,56 +420,31 @@
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
-        if reduced:  # if not async, can release immediately, else release in when work finished
-            if async_reduce:
-                # dirty fix by installing callback
-                assert not hasattr(p, "_release_grad_chunk_cb")
-
-                def _release_grad_chunk_cb():
-                    grad_chunk.wait_async_reduce()
-                    GeminiDDP.release_grad_chunk_handle(
-                        chunk_manager,
-                        grads_device,
-                        master_weights,
-                        enable_gradient_accumulation,
-                        p,
-                        chunk,
-                        grad_chunk,
-                    )
-
-                p._release_grad_chunk_cb = _release_grad_chunk_cb
-            else:
-                GeminiDDP.release_grad_chunk_handle(
-                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-                )
-        return empty_grad
-
-    @staticmethod
-    def release_grad_chunk_handle(
-        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-    ):
-        if not chunk_manager.reuse_fp16_chunk:
-            if chunk.keep_gathered:
-                chunk_manager.fake_release_chunk(chunk)
-            else:
-                chunk_manager.release_chunk(chunk)
-        if grad_chunk.is_gathered:
-            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-        else:
-            grad_chunk.cuda_shard.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-        # check overflow elements
-        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-        if chunk.l2_norm_flag:
-            grad_chunk.set_l2_norm()
-        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-        if not (master_weights) or (enable_gradient_accumulation):
-            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        with torch.cuda.stream(async_reduce_stream):
+            chunk_manager.reduce_chunk(grad_chunk)
+
+            if not chunk_manager.reuse_fp16_chunk:
+                if chunk.keep_gathered:
+                    chunk_manager.fake_release_chunk(chunk)
+                else:
+                    chunk_manager.release_chunk(chunk)
+            if grad_chunk.is_gathered:
+                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                if chunk.extra_dp_group is not None:
+                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+            else:
+                grad_chunk.cuda_shard.div_(chunk.pg_size)
+                if chunk.extra_dp_group is not None:
+                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+            # check overflow elements
+            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+            if chunk.l2_norm_flag:
+                grad_chunk.set_l2_norm()
+            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+            if not (master_weights) or (enable_gradient_accumulation):
+                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)

     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
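A detail that makes the rewritten grad_handle work in both modes: torch.cuda.stream(None) is documented as a no-op context manager. When async reduce is disabled, async_reduce_stream is None and the whole block (reduce, rescale, overflow count, chunk moves) simply runs on the current stream as before; when it is a real stream, the same code is enqueued there and fenced later by _post_backward instead of by per-parameter callbacks. A tiny check of that property:

    import torch

    # torch.cuda.stream(None) leaves the current stream unchanged, so one code
    # path can serve both the synchronous and the asynchronous reduce mode.
    default = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    with torch.cuda.stream(None):
        assert torch.cuda.current_stream() == default   # unchanged: synchronous path

    with torch.cuda.stream(side):
        assert torch.cuda.current_stream() == side      # switched: async path

The division, overflow counting, and chunk moves that used to live in release_grad_chunk_handle are now inlined under that context, so on the async path they follow the reduce-scatter on the dedicated stream in order, with no callback bookkeeping on the parameters.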