Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 10:06:44 +00:00
[gemini] optimize reduce scatter d2h copy (#5760)
* [gemini] optimize reduce scatter d2h copy
* [fix] fix missing reduce variable
* [refactor] remove legacy async reduce scatter code
* [gemini] missing sync
* Revert "[refactor] remove legacy async reduce scatter code"
This reverts commit 58ad76d466.
* [gemini] further optimize with async all reduce
* [fix] pass flag from manager to chunk
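The optimization follows the usual CUDA overlap recipe: a device-to-host copy issued with non_blocking=True into pinned (page-locked) memory is enqueued on the current stream and returns immediately, so the reduce-scatter output can be drained to the CPU while other work proceeds, provided the caller synchronizes before reading the host buffer. A minimal sketch of that pattern, with illustrative tensor names and sizes (not taken from the codebase):

import torch

device = torch.device("cuda")
grad_shard = torch.randn(1 << 20, device=device)  # stand-in for a reduced gradient shard

# The destination must be pinned for the D2H copy to be truly asynchronous.
cpu_shard = torch.empty(grad_shard.shape, dtype=grad_shard.dtype, device="cpu", pin_memory=True)

# Enqueue the copy on the current stream; control returns to Python immediately.
cpu_shard.copy_(grad_shard, non_blocking=True)

# ... other kernels or collectives can be launched here and overlap with the copy ...

# Caller-owned synchronization point before the host buffer is read.
torch.cuda.current_stream().synchronize()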
@@ -316,12 +316,13 @@ class Chunk:
         if self.shard_device.type == "cpu":
             self.cuda_shard = None
 
-    def shard_move(self, device: torch.device, force_copy: bool = False):
+    def shard_move(self, device: torch.device, force_copy: bool = False, non_blocking=False):
         """Move the shard tensor in the chunk.
 
         Args:
             device: the device to which the shard will move
             force_copy: if True, copy function is called mandatorily
+            non_blocking: if True, the operation is non-blocking, the caller is responsible for synchronization
         """
         # sanity check
         assert not self.is_gathered
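Because non_blocking=True only enqueues the copy, the new docstring makes synchronization the caller's responsibility. One hedged illustration of that contract using a CUDA event; move_shard_async below is an invented helper for illustration, not part of the ColossalAI API:

import torch

def move_shard_async(shard: torch.Tensor, pinned_dst: torch.Tensor) -> torch.cuda.Event:
    # Hypothetical helper: enqueue a D2H copy and return an event to wait on.
    pinned_dst.copy_(shard, non_blocking=True)
    done = torch.cuda.Event()
    done.record()  # recorded on the current stream, after the enqueued copy
    return done

shard = torch.randn(4096, device="cuda")
dst = torch.empty(4096, device="cpu", pin_memory=True)
event = move_shard_async(shard, dst)
# ... unrelated GPU work can be launched here ...
event.synchronize()  # caller-owned sync point, as the docstring requires
host_sum = dst.sum()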
@@ -329,7 +330,7 @@ class Chunk:
         # just use another way for the movement
         if not self.optim_sync_flag:
             assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
-            self.__paired_shard_move()
+            self.__paired_shard_move(non_blocking=non_blocking)
             self.optim_sync_flag = True
             return
 
@@ -339,7 +340,7 @@ class Chunk:
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -349,11 +350,11 @@ class Chunk:
 
             if self.pin_memory:
                 if force_copy or not self.cpu_vis_flag:
-                    self.cpu_shard.copy_(self.cuda_shard)
+                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=non_blocking)
                 # if cpu_shard has been visited
                 # copy operation is not need
             else:
-                self.cpu_shard = self.cuda_shard.cpu()
+                self.cpu_shard = self.cuda_shard.to("cpu", non_blocking=non_blocking)
             self.cpu_vis_flag = True
             self.cuda_shard = None
         else:
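The two branches above differ in where the host buffer comes from: with pin_memory the pre-allocated pinned cpu_shard is reused via copy_, otherwise a fresh host tensor is allocated by .to("cpu", ...). A non_blocking device-to-host copy only overlaps when the destination is pinned; into pageable memory it degrades to an effectively synchronous transfer. A rough sketch of the distinction (buffer names are illustrative):

import torch

src = torch.randn(1 << 18, device="cuda")

# Pinned destination: the copy is asynchronous with respect to the host.
pinned = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
pinned.copy_(src, non_blocking=True)

# Pageable destination: non_blocking=True is accepted, but the D2H transfer
# is effectively synchronous.
pageable = src.to("cpu", non_blocking=True)

torch.cuda.current_stream().synchronize()  # required before reading `pinned`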
@@ -542,7 +543,7 @@ class Chunk:
             free_storage(self.cuda_global_chunk)
         self.is_gathered = False
 
-    def __paired_shard_move(self):
+    def __paired_shard_move(self, non_blocking=False):
         assert self.paired_chunk is not None, "chunks should be paired before training"
         optim_chunk = self.paired_chunk
         assert self.chunk_size == optim_chunk.chunk_size
@@ -550,7 +551,7 @@ class Chunk:
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
         # avoid to transform FP32 in CPU
         self.cuda_shard = temp.to(self.dtype)
 
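The comment "avoid to transform FP32 in CPU" captures a second micro-optimization: move the FP32 optimizer shard to the device first (non-blocking, from pinned memory), then cast to the parameter dtype with a fast GPU kernel instead of a slow host-side conversion. A hedged sketch of that ordering with illustrative names:

import torch

fp32_cpu_shard = torch.randn(1 << 18, dtype=torch.float32, pin_memory=True)

# Move first, cast on the GPU: the H2D copy can be non-blocking and the
# fp32 -> fp16 conversion runs as a device kernel.
gpu_fp32 = fp32_cpu_shard.to("cuda", non_blocking=True)
gpu_fp16 = gpu_fp32.to(torch.float16)

# Casting on the CPU first would serialize a host-side conversion before
# the transfer:
#   gpu_fp16 = fp32_cpu_shard.to(torch.float16).to("cuda")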
@@ -117,7 +117,7 @@ class ChunkManager:
             return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_accelerator().get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device(), non_blocking=async_access)
         maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
         return maybe_work
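At the manager level the same async_access flag now both makes the host-to-device shard move non-blocking and requests an asynchronous collective, whose handle (maybe_work) the caller can wait on later. A hedged sketch of how such a handle is typically consumed; the surrounding function names, and the assumption that the manager method shown above is exposed as access_chunk(chunk, async_access=...), are illustrative rather than confirmed call sites:

import torch.distributed as dist

def prefetch(chunk_manager, chunk):
    # Kick off a possibly-asynchronous access of the chunk (assumed API).
    work = chunk_manager.access_chunk(chunk, async_access=True)
    return work  # Optional[dist.Work]; None when the access completed eagerly

def consume(chunk, work):
    if work is not None:
        work.wait()  # complete the async collective before reading the chunk
    # ... use the gathered chunk payload here ...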