Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)
[gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix
@@ -38,7 +38,8 @@ class ChunkManager:
         tensor: torch.Tensor,
         group_type: str,
         config_key: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
+        extra_dp_group: ProcessGroup = None,
         cpu_offload: bool = False,
         pin_memory: bool = False,
     ) -> None:
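For orientation, the renamed `zero_group` shards each chunk ZeRO-style, while the new `extra_dp_group` replicates chunks across an additional data-parallel dimension. Below is a minimal sketch of how a caller might build two such orthogonal groups and feed them to the signature above; the rank layout, the `build_groups` helper, and the trailing `register_tensor` call are illustrative assumptions, not code from this diff.

```python
import torch.distributed as dist

def build_groups(zero_size: int):
    """Split the world into ZeRO shard groups and orthogonal extra-DP groups.

    Assumed layout: consecutive ranks form one ZeRO group; ranks at the same
    offset within their ZeRO group form one extra data-parallel group.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % zero_size == 0
    extra_dp_size = world_size // zero_size

    zero_group = extra_dp_group = None
    # dist.new_group must be called by every rank for every group, in the
    # same order, even for groups the calling rank does not belong to.
    for k in range(extra_dp_size):
        ranks = list(range(k * zero_size, (k + 1) * zero_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            zero_group = group
    for offset in range(zero_size):
        ranks = list(range(offset, world_size, zero_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            extra_dp_group = group
    return zero_group, extra_dp_group
```

With those groups in hand, the new keyword arguments map directly onto the signature, e.g. `manager.register_tensor(tensor=p, group_type="fp16_param", config_key=p.numel(), zero_group=zero_group, extra_dp_group=extra_dp_group)`, where the `group_type` and `config_key` values shown are placeholders.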
@@ -76,15 +77,16 @@ class ChunkManager:
         if tensor.numel() > chunk_size:
             chunk_size = tensor.numel()
-            dp_size = dist.get_world_size(process_group)
+            dp_size = dist.get_world_size(zero_group)
             chunk_size = chunk_size + (-chunk_size % dp_size)

         chunk = Chunk(
             chunk_size=chunk_size,
-            process_group=process_group,
+            zero_group=zero_group,
             dtype=tensor.dtype,
             cpu_shard_init=cpu_offload,
             pin_memory=pin_memory,
+            extra_dp_group=extra_dp_group,
             **chunk_kwargs,
         )
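The `chunk_size + (-chunk_size % dp_size)` line rounds the chunk size up to the nearest multiple of the ZeRO group's world size, so an oversized tensor still shards evenly across `zero_group`. A small self-contained check of that idiom (plain Python, no assumptions beyond the expression itself):

```python
def round_up_to_multiple(chunk_size: int, dp_size: int) -> int:
    # In Python, -x % n is non-negative: it equals the padding needed to
    # reach the next multiple of n (0 if chunk_size already divides evenly).
    return chunk_size + (-chunk_size % dp_size)

assert round_up_to_multiple(1000, 8) == 1000  # already a multiple: unchanged
assert round_up_to_multiple(1001, 8) == 1008  # padded by 7, up to 126 * 8
assert round_up_to_multiple(7, 8) == 8        # tiny tensor pads to one element per rank
```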
@@ -288,4 +290,4 @@ class ChunkManager:
             # Release accumulated_grad
             free_storage(accumulated_grad)

-        return grad_chunk
+        return grad_chunk
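For context on what the extra group buys: with two orthogonal groups, a gradient chunk is typically reduced in two steps, a reduce-scatter across the ZeRO group followed by an all-reduce of the local shard across the extra data-parallel group. The sketch below illustrates that pattern on a flat tensor; it is an assumption about the intent of `extra_dp_group`, not code from this commit (the actual reduction lives in `Chunk`, which this diff does not show).

```python
import torch
import torch.distributed as dist

def reduce_grad_chunk(flat_grad: torch.Tensor, zero_group,
                      extra_dp_group=None) -> torch.Tensor:
    """Two-level gradient reduction sketch for a padded flat gradient chunk."""
    zero_size = dist.get_world_size(zero_group)
    # flat_grad was padded to a multiple of zero_size (see the rounding hunk
    # above), so every rank receives an equal-sized shard.
    shard = torch.empty(flat_grad.numel() // zero_size,
                        dtype=flat_grad.dtype, device=flat_grad.device)
    dist.reduce_scatter_tensor(shard, flat_grad, group=zero_group)
    shard /= zero_size
    if extra_dp_group is not None:
        # Average the shard across replicas in the extra DP dimension.
        dist.all_reduce(shard, group=extra_dp_group)
        shard /= dist.get_world_size(extra_dp_group)
    return shard
```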