[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support FP8 communication for Torch DDP gradients, FSDP gradients, and FSDP parameters (a DDP hook sketch follows the commit metadata below)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fix some minor bugs observed in the unit tests

* add all_gather_into_tensor_flat_fp8

* skip the FP8 tests when torch < 2.2.0 (a minimal skip-guard sketch follows the commit metadata below)

* add fp8_comm flag

* rebase latest fp8 operators

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hanks
Date: 2024-08-08 15:55:01 +08:00
Committed by: GitHub
Parent: 7739629b9d
Commit: b480eec738
14 changed files with 602 additions and 14 deletions
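
For the Torch DDP gradient path, FP8 compression naturally fits torch's DDP communication-hook API (DistributedDataParallel.register_comm_hook). Below is a minimal, self-contained sketch of such a hook; the hook name, the per-bucket scaling, and the gather-then-reduce-locally strategy are illustrative assumptions rather than the exact code added by this PR, which lives in colossalai.quantization.fp8.

import torch
import torch.distributed as dist

# Illustrative sketch of an FP8 gradient-compression hook for torch DDP.
# The hook name, the per-bucket scaling, and the gather-then-reduce-locally
# strategy are assumptions; the PR's actual hook lives in colossalai.quantization.fp8.
def fp8_compress_hook(process_group, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)

    grad = bucket.buffer()
    # per-bucket scale so the gradient fits the e4m3 dynamic range
    amax = grad.abs().max().clamp(min=1e-12)
    scale = (amax / torch.finfo(torch.float8_e4m3fn).max).reshape(1)
    fp8_grad = (grad / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    # NCCL cannot reduce fp8 directly, so gather every rank's fp8 payload
    # (as raw bytes) together with its scale, then dequantize and average locally.
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)
    payloads = [torch.empty_like(fp8_grad) for _ in range(world_size)]
    fut = dist.all_gather(payloads, fp8_grad, group=group, async_op=True).get_future()

    def decompress(_):
        out = torch.zeros_like(grad)
        for buf, s in zip(payloads, scales):
            out += buf.view(torch.float8_e4m3fn).to(grad.dtype) * s
        return out / world_size

    return fut.then(decompress)

# Usage: ddp_model.register_comm_hook(state=None, hook=fp8_compress_hook)

Because the gather here ships every rank's full bucket, a production hook would normally use a reduce-scatter style exchange instead; the sketch favors brevity over bandwidth.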
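The FP8 dtypes and casts exercised by the unit tests require a recent torch, hence the skip for torch < 2.2.0. A minimal sketch of such a guard follows; the test body is a stand-in quantize/dequantize round trip, not the PR's actual test.

import pytest
import torch
from packaging import version

# Guard fp8 tests on older torch; the test body below is a stand-in check.
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse("2.2.0"),
    reason="fp8 communication tests require torch >= 2.2.0",
)
def test_fp8_cast_round_trip():
    torch.manual_seed(0)
    x = torch.randn(64)
    # quantize to e4m3 with a per-tensor scale, then dequantize
    scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    y = (x / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale
    assert torch.allclose(x, y, rtol=0.13, atol=1e-2)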

@@ -166,6 +166,7 @@ class Chunk:
         self.grad_chunk = None
         # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
         self.grad_reduce_work = None
+        self.fp8_communication = False
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -521,9 +522,17 @@ class Chunk:
             alloc_storage(self.cuda_global_chunk)
             assert self.cuda_global_chunk.is_contiguous()
-            work = dist.all_gather_into_tensor(
-                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
-            )
+            if self.fp8_communication:
+                assert async_op == False, "fp8 all-gather does not support async_op!"
+                from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
+                work = all_gather_into_tensor_flat_fp8(
+                    self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
+                )
+            else:
+                work = dist.all_gather_into_tensor(
+                    self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+                )
             self.cuda_shard = None
             self.is_gathered = True

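The fp8 branch above calls the new all_gather_into_tensor_flat_fp8 helper with the flat output chunk, the local shard, the output shape, and the process group. A rough sketch of the idea behind such a helper is given below; the body (per-rank scales, byte-level all-gather, local dequantization) is an assumption about the approach, not the actual implementation in colossalai.quantization.fp8, which also has to deal with padding and shape handling.

import torch
import torch.distributed as dist

# Rough sketch of the idea behind all_gather_into_tensor_flat_fp8; the real helper
# in colossalai.quantization.fp8 may differ (scaling scheme, padding, async handling).
def all_gather_flat_fp8_sketch(output: torch.Tensor, shard: torch.Tensor,
                               output_shape: torch.Size, group=None) -> None:
    world_size = dist.get_world_size(group)

    # quantize the local shard with a per-rank scale
    amax = shard.abs().max().clamp(min=1e-12)
    scale = (amax / torch.finfo(torch.float8_e4m3fn).max).reshape(1)
    fp8_shard = (shard / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    # gather fp8 payloads (as raw bytes) and per-rank scales
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)
    payload = torch.empty(world_size * fp8_shard.numel(), dtype=torch.uint8, device=shard.device)
    dist.all_gather_into_tensor(payload, fp8_shard, group=group)

    # dequantize each rank's slice into the flat output buffer
    # (assumes output.numel() == world_size * shard.numel(); the real helper's
    # output_shape / padding handling is omitted here)
    flat_out = output.view(-1)
    n = fp8_shard.numel()
    for rank in range(world_size):
        piece = payload[rank * n : (rank + 1) * n].view(torch.float8_e4m3fn)
        flat_out[rank * n : (rank + 1) * n].copy_(piece.to(output.dtype) * scales[rank])
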
@@ -26,6 +26,7 @@ class ChunkManager:
         init_device: Optional[torch.device] = None,
         reuse_fp16_chunk: bool = True,
         max_prefetch: int = 0,
+        fp8_communication: bool = False,
     ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -44,6 +45,7 @@ class ChunkManager:
         self.accumulating_grads = False
         self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
         self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
+        self.fp8_communication = fp8_communication
 
     def register_tensor(
         self,
@@ -101,6 +103,8 @@ class ChunkManager:
                 extra_dp_group=extra_dp_group,
                 **chunk_kwargs,
             )
+            if self.fp8_communication:
+                chunk.fp8_communication = True
             chunk_group.append(chunk)
             chunk.append_tensor(tensor)

@@ -98,6 +98,7 @@ class GeminiDDP(ModelWrapper):
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
         enable_async_reduce: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -122,6 +123,8 @@ class GeminiDDP(ModelWrapper):
             verbose=verbose,
             max_prefetch=max_prefetch,
         )
+        if fp8_communication:
+            self.chunk_manager.fp8_communication = True
         self.gemini_manager = GeminiManager(
             placement_policy,
             self.chunk_manager,
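
With the flag threaded from GeminiDDP into ChunkManager and each Chunk, FP8 communication can be switched on from the plugin level. A hedged usage sketch follows; the fp8_communication keyword on GeminiPlugin is assumed to mirror the GeminiDDP argument above, and the model and optimizer are placeholders for a real training setup.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Run with torchrun; fp8_communication=True is assumed to be forwarded to GeminiDDP,
# and the tiny model/optimizer below stand in for a real training setup.
colossalai.launch_from_torch()

plugin = GeminiPlugin(precision="bf16", fp8_communication=True)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)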