[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)

* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] add test

* [gemini] rename func

* [gemini] update llama benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] use tensor counter

* [gemini] change default config in GeminiPlugin and GeminiDDP

* [chore] typo

* [gemini] fix sync issue & add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
botbw
2024-05-24 10:31:16 +08:00
committed by GitHub
parent 85946d4236
commit 2fc85abf43
11 changed files with 130 additions and 45 deletions

View File

@@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
master_weights: bool = True,
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
enable_async_reduce: bool = True,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
if is_ddp_ignored(p):
continue
if p.requires_grad:
assert not hasattr(p, "_grad_handle")
p._grad_handle = p.register_hook(
partial(
GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p,
async_reduce=enable_async_reduce,
)
)
@@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
for param in self.param2name:
if hasattr(param, "_release_grad_chunk_cb"):
param._release_grad_chunk_cb()
delattr(param, "_release_grad_chunk_cb")
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
@@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
master_weights: bool,
enable_gradient_accumulation: bool,
p: nn.Parameter,
async_reduce: bool,
):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper):
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)
else:
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
if reduced: # if not async, can release immediately, else release in when work finished
if async_reduce:
# dirty fix by installing callback
assert not hasattr(p, "_release_grad_chunk_cb")
def _release_grad_chunk_cb():
grad_chunk.wait_async_reduce()
GeminiDDP.release_grad_chunk_handle(
chunk_manager,
grads_device,
master_weights,
enable_gradient_accumulation,
p,
chunk,
grad_chunk,
)
p._release_grad_chunk_cb = _release_grad_chunk_cb
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
GeminiDDP.release_grad_chunk_handle(
chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
)
return empty_grad
@staticmethod
def release_grad_chunk_handle(
chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
):
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)
else:
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)