Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)
* [gemini] async grad chunk reduce (all-reduce & reduce-scatter)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [gemini] add test
* [gemini] rename func
* [gemini] update llama benchmark
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [gemini] use tensor counter
* [gemini] change default config in GeminiPlugin and GeminiDDP
* [chore] typo
* [gemini] fix sync issue & add test cases
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -164,6 +164,8 @@ class Chunk:
         self.l2_norm = None
 
         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
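The new `grad_reduce_work` field stores the handle that `torch.distributed` collectives return when called with `async_op=True`; `None` means the last reduction was synchronous (or has already been waited on). A minimal runnable sketch of that handle pattern, using a throwaway single-process gloo group purely for illustration:

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    grad = torch.ones(4)
    work = dist.all_reduce(grad, async_op=True)  # returns a Work handle immediately
    # ... overlap other computation here ...
    work.wait()  # block only once the reduced values are actually needed
    dist.destroy_process_group()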
@@ -244,7 +246,7 @@ class Chunk:
         assert self.cuda_shard is not None  # only check on CUDA
         valid_tensor = self.cuda_shard[: self.valid_end]
 
-        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+        return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()
 
     def set_l2_norm(self) -> None:
         """Record l2 norm of this chunks on CUDA."""
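With the `.item()` calls dropped, `has_inf_or_nan` now returns a zero-dim boolean tensor instead of a Python bool. `.item()` forces a blocking device-to-host copy, so removing it keeps the overflow check on the accelerator. A small sketch of the difference (on CPU for simplicity):

    import torch

    valid_tensor = torch.tensor([1.0, float("inf"), 3.0])
    flag = torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()
    # `flag` stays a tensor, so it can be accumulated into a device-side
    # counter later without synchronizing the stream.
    print(type(flag), flag)  # <class 'torch.Tensor'> tensor(True)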
@@ -374,37 +376,49 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
-            if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            if self.extra_dp_group is not None:  # cannot guarantee the order of multiple all-reduce
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )
 
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
 
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
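Two details in this hunk carry the design: every collective now returns its `Work` handle into `grad_reduce_work` when `async_op=True`, and when a second collective must touch the same buffer (the `extra_dp_group` all-reduce after the main one), the code first calls `wait_async_reduce()`, since the completion order of two in-flight collectives on one tensor is not guaranteed. A hedged sketch of the reduce-scatter branch, with `pg` and `pg_size` standing in for the chunk's process group and its size:

    import torch
    import torch.distributed as dist

    def reduce_scatter_chunk(global_chunk: torch.Tensor, pg, pg_size: int, async_op: bool = False):
        # each rank receives the reduced sum of its 1/pg_size slice
        shard = torch.empty(
            global_chunk.numel() // pg_size, dtype=global_chunk.dtype, device=global_chunk.device
        )
        input_list = list(torch.chunk(global_chunk, chunks=pg_size, dim=0))
        work = dist.reduce_scatter(shard, input_list, group=pg, async_op=async_op)
        # the caller must work.wait() before reading `shard` or freeing `global_chunk`
        return shard, work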
@@ -41,7 +41,7 @@ class ChunkManager:
         self.reuse_fp16_chunk = reuse_fp16_chunk
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
-        self.overflow_counter = 0
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
 
     def register_tensor(
         self,
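Making `overflow_counter` a device tensor is what lets the tensor-valued overflow flags from `has_inf_or_nan` be accumulated without a host sync: `counter += flag` is just another op queued on the stream. A sketch, on CPU for simplicity:

    import torch

    overflow_counter = torch.tensor([0], dtype=torch.int)
    for chunk_has_inf_or_nan in (torch.tensor(False), torch.tensor(True)):
        overflow_counter += chunk_has_inf_or_nan  # no per-chunk .item() round-trip
    print(overflow_counter)  # tensor([1], dtype=torch.int32)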
@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
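A hedged usage sketch of the new signature, assuming `chunk_manager` and `grad_chunk` are already set up as in the surrounding code:

    reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=True)
    if reduced:
        # ... overlap further backward computation here ...
        grad_chunk.wait_async_reduce()  # block only when the reduced shard is needed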
@@ -272,7 +272,7 @@ class ChunkManager:
         return grad_chunk
 
     def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
-        """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
+        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""
 
         assert chunk.grad_chunk is not None
@@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
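Per the commit message, the same flag is exposed through GeminiPlugin with the same default. A hedged usage sketch (model and other plugin arguments elided; the keyword is assumed to be forwarded to GeminiDDP as in this commit):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    plugin = GeminiPlugin(enable_async_reduce=True)  # default after this commit
    booster = Booster(plugin=plugin)
    # set enable_async_reduce=False to fall back to fully synchronous reduction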
@@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )
 
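The registration above uses the standard `Tensor.register_hook` plus `functools.partial` pattern: the hook only ever receives the gradient, so every other argument is bound in advance, and whatever the hook returns replaces the gradient that autograd stores. A minimal runnable sketch of that pattern (the handler here is a stand-in, not Gemini's):

    import torch
    from functools import partial

    def grad_handle(grad, p):
        print("got grad for parameter of shape", tuple(p.shape))
        return torch.empty_like(grad)  # returned tensor replaces the stored grad

    p = torch.nn.Parameter(torch.randn(3))
    p._grad_handle = p.register_hook(partial(grad_handle, p=p))
    (p * 2.0).sum().backward()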
@@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
             setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
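This is the draining side of the deferred-release mechanism: during backward, each parameter whose grad chunk was reduced asynchronously stashes a callback, and `_post_backward` runs them after the whole backward pass has been issued, so each `Work` handle is waited on as late as possible. A small runnable sketch of the pattern with stand-in objects:

    class FakeParam:  # stand-in for nn.Parameter
        pass

    params = [FakeParam() for _ in range(3)]
    for i, p in enumerate(params):
        # in Gemini the callback waits on the async reduce, then releases the chunk
        p._release_grad_chunk_cb = lambda i=i: print(f"waited on + released chunk {i}")

    for p in params:  # what _post_backward does
        if hasattr(p, "_release_grad_chunk_cb"):
            p._release_grad_chunk_cb()
            delattr(p, "_release_grad_chunk_cb")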
@@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper):
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk)
-        if reduced:
-            if not chunk_manager.reuse_fp16_chunk:
-                if chunk.keep_gathered:
-                    chunk_manager.fake_release_chunk(chunk)
-                else:
-                    chunk_manager.release_chunk(chunk)
-            if grad_chunk.is_gathered:
-                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-            else:
-                grad_chunk.cuda_shard.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-            # check overflow elements
-            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-            if chunk.l2_norm_flag:
-                grad_chunk.set_l2_norm()
-            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-            if not (master_weights) or (enable_gradient_accumulation):
-                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+        if reduced:  # if not async, can release immediately; else, release when the work has finished
+            if async_reduce:
+                # dirty fix by installing callback
+                assert not hasattr(p, "_release_grad_chunk_cb")
+
+                def _release_grad_chunk_cb():
+                    grad_chunk.wait_async_reduce()
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager,
+                        grads_device,
+                        master_weights,
+                        enable_gradient_accumulation,
+                        p,
+                        chunk,
+                        grad_chunk,
+                    )
+
+                p._release_grad_chunk_cb = _release_grad_chunk_cb
+            else:
+                GeminiDDP.release_grad_chunk_handle(
+                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                )
         return empty_grad
 
+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
+        if not chunk_manager.reuse_fp16_chunk:
+            if chunk.keep_gathered:
+                chunk_manager.fake_release_chunk(chunk)
+            else:
+                chunk_manager.release_chunk(chunk)
+        if grad_chunk.is_gathered:
+            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        else:
+            grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+        # check overflow elements
+        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+        if chunk.l2_norm_flag:
+            grad_chunk.set_l2_norm()
+        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+        if not (master_weights) or (enable_gradient_accumulation):
+            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
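The `div_` calls in `release_grad_chunk_handle` are the usual sum-to-mean rescaling: the collectives sum gradients over `pg_size` ranks and, if present, over `extra_dp_size` replica groups, so the chunk is divided by both factors afterwards. A tiny numeric check of that arithmetic:

    import torch

    pg_size, extra_dp_size = 4, 2
    g = torch.full((3,), 1.0)                  # the per-replica gradient
    summed = g * pg_size * extra_dp_size       # what the summing collectives produce
    summed.div_(pg_size).div_(extra_dp_size)   # the two div_ calls above
    assert torch.allclose(summed, g)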
@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module
 
     def check_local_overflow(self) -> bool:
-        return self.module.chunk_manager.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter.item() > 0
 
     def pre_zero_grad(self) -> None:
-        self.module.chunk_manager.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter.zero_()
 
 
 class GeminiOptimizer(OptimizerWrapper):
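With the counter now a tensor, the mixin crosses the device boundary exactly once per step (the `.item()` in `check_local_overflow`) and resets it in place with `zero_()` rather than rebinding the attribute, which would detach it from the device. A sketch of that lifecycle on CPU:

    import torch

    overflow_counter = torch.tensor([0], dtype=torch.int)
    overflow_counter += torch.tensor(True)  # accumulated during backward
    print(overflow_counter.item() > 0)      # check_local_overflow -> True
    overflow_counter.zero_()                # pre_zero_grad: reset in place
    print(overflow_counter.item() > 0)      # False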