Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* implement a communication hook for the FSDP params all-gather
* add unit tests for the fp8 operators
* support fp8 communication in GeminiPlugin
* update the training scripts to support FSDP and fp8 communication
* fix some minor bugs observed in the unit tests
* add all_gather_into_tensor_flat_fp8
* skip the tests if torch < 2.2.0
* add the fp8_comm flag
* rebase onto the latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
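The feature is consumed through the booster plugins rather than called directly. Below is a hedged usage sketch of how it might be enabled with Gemini; the keyword name `fp8_communication` is inferred from the attribute added in the diff further down, the model/optimizer setup is placeholder code rather than an excerpt from the updated training scripts, and per the commit message the DDP and FSDP plugins gained equivalent support.

    # Hedged usage sketch (not taken from this PR's training scripts): enabling
    # fp8 communication through the plugin flag. The keyword `fp8_communication`
    # mirrors the attribute added in the diff below; model/optimizer are placeholders.
    # Run under torchrun, e.g. `torchrun --nproc_per_node=2 this_script.py`.
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024)
    )
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    # With fp8_communication=True, Gemini's chunk all-gather (see the diff below)
    # and related collectives send fp8 payloads instead of fp16/bf16/fp32 ones.
    plugin = GeminiPlugin(fp8_communication=True)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)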
@@ -166,6 +166,7 @@ class Chunk:
         self.grad_chunk = None
         # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
         self.grad_reduce_work = None
+        self.fp8_communication = False
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -521,9 +522,17 @@ class Chunk:
 
             alloc_storage(self.cuda_global_chunk)
             assert self.cuda_global_chunk.is_contiguous()
-            work = dist.all_gather_into_tensor(
-                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
-            )
+            if self.fp8_communication:
+                assert async_op == False, "fp8 all-gather does not support async_op!"
+                from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
+                work = all_gather_into_tensor_flat_fp8(
+                    self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
+                )
+            else:
+                work = dist.all_gather_into_tensor(
+                    self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+                )
 
             self.cuda_shard = None
             self.is_gathered = True
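For readers unfamiliar with what a flat fp8 all-gather does conceptually, below is a minimal, self-contained sketch. It is not the ColossalAI implementation of all_gather_into_tensor_flat_fp8 (which lives in colossalai/quantization/fp8.py and handles scaling, dtypes, and output shapes more carefully); the helper name `_all_gather_fp8_sketch` and its exact signature are hypothetical.

    # Illustrative sketch only, assuming torch >= 2.1 for float8 dtypes.
    # Each rank quantizes its shard to fp8 with one fp32 scale, the raw bytes and
    # scales are all-gathered, and every rank dequantizes into the flat output.
    import torch
    import torch.distributed as dist


    def _all_gather_fp8_sketch(output: torch.Tensor, shard: torch.Tensor, group=None) -> None:
        """Synchronously all-gather `shard` from every rank into the flat `output` buffer,
        transmitting fp8 payloads plus one fp32 scale per rank."""
        world_size = dist.get_world_size(group)
        fp8_max = torch.finfo(torch.float8_e4m3fn).max

        # Per-rank scale so the shard fits the fp8 dynamic range.
        scale = shard.abs().max().float().clamp(min=1e-12) / fp8_max
        shard_fp8 = (shard.float() / scale).to(torch.float8_e4m3fn)

        # Collectives generally cannot reduce or interpret fp8, so ship raw bytes.
        gathered_bytes = torch.empty(
            world_size * shard.numel(), dtype=torch.uint8, device=shard.device
        )
        dist.all_gather_into_tensor(gathered_bytes, shard_fp8.view(torch.uint8), group=group)

        # Exchange the per-rank scales as well.
        scales = torch.empty(world_size, dtype=torch.float32, device=shard.device)
        dist.all_gather_into_tensor(scales, scale.reshape(1), group=group)

        # Dequantize each rank's segment into the flat output buffer.
        gathered_fp8 = gathered_bytes.view(torch.float8_e4m3fn).view(world_size, -1)
        for rank in range(world_size):
            seg = output.view(-1).narrow(0, rank * shard.numel(), shard.numel())
            seg.copy_(gathered_fp8[rank].float() * scales[rank])

This also explains the trade-off visible in the diff: the fp8 branch asserts async_op == False because the quantize/all-gather/dequantize sequence runs synchronously, whereas the plain dist.all_gather_into_tensor path can return an async work handle.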