[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* simplify tests
Author: flybird11111
Date: 2023-11-16 21:03:04 +08:00
Committed by: GitHub
Parent: b2ad0d9e8f
Commit: 3e02154710

10 changed files with 96 additions and 137 deletions

@@ -61,12 +61,13 @@ class Chunk:
     def __init__(
         self,
         chunk_size: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
         dtype: torch.dtype,
         init_device: Optional[torch.device] = None,
         cpu_shard_init: bool = False,
         keep_gathered: bool = False,
         pin_memory: bool = False,
+        extra_dp_group: ProcessGroup = None,
     ) -> None:
         """
         Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +77,7 @@ class Chunk:
         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ProcessGroup): the process group of this chunk
+            zero_group (ProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
             init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                 The default value is None, which is the current GPU
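
For orientation, a minimal sketch of constructing a Chunk with the renamed zero_group and the new extra_dp_group parameter. The import path and the concrete size are assumptions for illustration, and the two process groups are presumed to exist already (a way to build them is sketched after the next hunk):

import torch
from colossalai.zero.gemini.chunk import Chunk  # assumed import path

# `zero_group` shards the chunk across its ranks; `extra_dp_group`
# replicates each shard and is only touched for the extra all-reduce hops.
chunk = Chunk(
    chunk_size=1024 * 1024,         # illustrative size
    zero_group=zero_group,
    dtype=torch.float16,
    extra_dp_group=extra_dp_group,  # optional; None means no extra DP
)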
@@ -90,9 +91,11 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0
-        self.torch_pg = process_group
+        self.torch_pg = zero_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
+        self.extra_dp_group = extra_dp_group
+        self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1
         # the chunk size should be divisible by the dp degree
         if not keep_gathered:
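
The new extra_dp_size bookkeeping implies that the full data-parallel world factors into a ZeRO (sharding) dimension times an extra (replication) dimension. A minimal sketch of wiring such groups for 4 ranks; the 2x2 layout and the group memberships are illustrative assumptions, not taken from the commit:

import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()

# Hypothetical 2x2 layout: ranks {0,1} and {2,3} form the ZeRO (sharding)
# groups; ranks {0,2} and {1,3} form the extra data-parallel (replication)
# groups. Every rank must call new_group for all groups, in the same order.
zero_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
extra_dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]

zero_group = zero_groups[rank // 2]
extra_dp_group = extra_dp_groups[rank % 2]

# The factorization the chunk relies on: world = zero_size * extra_dp_size.
assert dist.get_world_size() == (
    dist.get_world_size(zero_group) * dist.get_world_size(extra_dp_group)
)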
@@ -384,14 +387,20 @@ class Chunk:
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
             dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
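
This hunk is the core of the feature: one reduction inside the ZeRO group (a reduce-scatter, or an all-reduce when the chunk is kept gathered), then a second all-reduce of the surviving shard across the extra DP group. A self-contained sketch of the sharded path, using illustrative names rather than the commit's own helpers:

import torch
import torch.distributed as dist

def reduce_chunk(cuda_global_chunk: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    # First hop: reduce-scatter across the ZeRO group, leaving each rank
    # with the summed shard it owns.
    zero_size = dist.get_world_size(zero_group)
    shard = torch.empty(
        cuda_global_chunk.numel() // zero_size,
        dtype=cuda_global_chunk.dtype,
        device=cuda_global_chunk.device,
    )
    input_list = list(torch.chunk(cuda_global_chunk, chunks=zero_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=zero_group)
    # Second hop: sum the matching shard held by every extra-DP replica.
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)
    return shard

Both hops sum, so averaging over the full data-parallel degree presumably has to divide by zero_size * extra_dp_size; that bookkeeping sits outside this hunk.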
@@ -608,10 +617,11 @@ class Chunk:
             # grad chunk is not initialized
             grad_chunk = Chunk(
                 chunk_size=self.chunk_size,
-                process_group=self.torch_pg,
+                zero_group=self.torch_pg,
                 dtype=self.dtype,
                 keep_gathered=self.keep_gathered,
                 pin_memory=self.pin_memory,
+                extra_dp_group=self.extra_dp_group,
             )
             grad_chunk.num_tensors = self.num_tensors
             grad_chunk.utilized_size = self.utilized_size
@@ -640,4 +650,4 @@ class Chunk:
             self.grad_chunk.l2_norm = None
             alloc_storage(self.grad_chunk.cuda_global_chunk)
-        return self.grad_chunk
+        return self.grad_chunk