Mirror of https://github.com/hpcaitech/ColossalAI.git,
synced 2025-09-04 02:26:51 +00:00
[gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix
This commit is contained in:
@@ -61,12 +61,13 @@ class Chunk:
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int,
|
||||
process_group: ProcessGroup,
|
||||
zero_group: ProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
cpu_shard_init: bool = False,
|
||||
keep_gathered: bool = False,
|
||||
pin_memory: bool = False,
|
||||
extra_dp_group: ProcessGroup = None,
|
||||
) -> None:
|
||||
"""
|
||||
Chunk: A container owning a piece of contiguous memory space for tensors
|
||||
@@ -76,7 +77,7 @@ class Chunk:
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in the chunk
|
||||
process_group (ProcessGroup): the process group of this chunk
|
||||
zero_group (ProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
|
||||
The default value is None, which is the current GPU
|
||||
@@ -90,9 +91,11 @@ class Chunk:
|
||||
self.chunk_size = chunk_size
|
||||
self.utilized_size = 0
|
||||
|
||||
self.torch_pg = process_group
|
||||
self.torch_pg = zero_group
|
||||
self.pg_size = dist.get_world_size(self.torch_pg)
|
||||
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||
self.extra_dp_group = extra_dp_group
|
||||
self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1
|
||||
|
||||
# the chunk size should be divisible by the dp degree
|
||||
if not keep_gathered:
|
||||
@@ -384,14 +387,20 @@ class Chunk:
|
||||
# just move cuda_global_chunk to cuda_shard
|
||||
# the communication is not necessary
|
||||
self.__scatter()
|
||||
if self.extra_dp_group is not None:
|
||||
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
|
||||
elif self.keep_gathered:
|
||||
# we use all-reduce here
|
||||
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
|
||||
if self.extra_dp_group is not None:
|
||||
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
|
||||
else:
|
||||
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
|
||||
|
||||
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
|
||||
if self.extra_dp_group is not None:
|
||||
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
|
||||
|
||||
free_storage(self.cuda_global_chunk)
|
||||
self.is_gathered = False
|
||||
@@ -608,10 +617,11 @@ class Chunk:
|
||||
# grad chunk is not initialized
|
||||
grad_chunk = Chunk(
|
||||
chunk_size=self.chunk_size,
|
||||
process_group=self.torch_pg,
|
||||
zero_group=self.torch_pg,
|
||||
dtype=self.dtype,
|
||||
keep_gathered=self.keep_gathered,
|
||||
pin_memory=self.pin_memory,
|
||||
extra_dp_group=self.extra_dp_group,
|
||||
)
|
||||
grad_chunk.num_tensors = self.num_tensors
|
||||
grad_chunk.utilized_size = self.utilized_size
|
||||
@@ -640,4 +650,4 @@ class Chunk:
|
||||
self.grad_chunk.l2_norm = None
|
||||
alloc_storage(self.grad_chunk.cuda_global_chunk)
|
||||
|
||||
return self.grad_chunk
|
||||
return self.grad_chunk
|
Reference in New Issue
Block a user