[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
This commit is contained in:
flybird11111
2023-11-16 21:03:04 +08:00
committed by GitHub
parent b2ad0d9e8f
commit 3e02154710
10 changed files with 96 additions and 137 deletions

View File

@@ -38,7 +38,8 @@ class ChunkManager:
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
extra_dp_group: ProcessGroup = None,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
@@ -76,15 +77,16 @@ class ChunkManager:
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = dist.get_world_size(process_group)
dp_size = dist.get_world_size(zero_group)
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
chunk_size=chunk_size,
process_group=process_group,
zero_group=zero_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)
@@ -288,4 +290,4 @@ class ChunkManager:
# Release accumulated_grad
free_storage(accumulated_grad)
return grad_chunk
return grad_chunk