mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[gemini] prefetch chunks
This commit is contained in:
@@ -357,14 +357,14 @@ class Chunk:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if not self.is_gathered:
|
||||
self.__gather()
|
||||
return self.__gather(async_op=async_access)
|
||||
self.__update_tensors_ptr()
|
||||
return None
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||
@@ -498,17 +498,19 @@ class Chunk:
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def __gather(self):
|
||||
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
alloc_storage(self.cuda_global_chunk)
|
||||
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
|
||||
return work
|
||||
return None
|
||||
|
||||
def __scatter(self):
|
||||
if self.keep_gathered:
|
||||
|
Reference in New Issue
Block a user