mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[gemini] prefetch chunks
This commit is contained in:
@@ -111,15 +111,16 @@ class ChunkManager:
|
||||
for group_name in self.chunk_groups:
|
||||
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk can be used for calculation."""
|
||||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
if chunk.device_type == "cpu":
|
||||
chunk.shard_move(get_accelerator().get_current_device())
|
||||
self.__add_accessed_chunk(chunk)
|
||||
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
return maybe_work
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""Scatter the chunk in CUDA."""
|
||||
@@ -251,10 +252,11 @@ class ChunkManager:
|
||||
for k, v in usage.items():
|
||||
self.total_mem[k] += v
|
||||
|
||||
def __add_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.access_chunk()
|
||||
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
maybe_work = chunk.access_chunk(async_access=async_access)
|
||||
self.accessed_chunks.add(chunk)
|
||||
self.accessed_mem += chunk.chunk_mem
|
||||
return maybe_work
|
||||
|
||||
def __sub_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.release_chunk()
|
||||
|
Reference in New Issue
Block a user