[gemini] prefetch chunks

This commit is contained in:
hxwang
2024-05-15 16:51:44 +08:00
parent 785cd9a9c9
commit 6e38eafebe
4 changed files with 96 additions and 17 deletions

View File

@@ -111,15 +111,16 @@ class ChunkManager:
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
# NOTE(review): this span comes from a diff hunk whose +/- markers, blank lines
# and indentation were lost, so two versions of `access_chunk` are interleaved.
# Going by the hunk header (-111,15 +111,16), the signature directly below and
# the bare `self.__add_accessed_chunk(chunk)` call appear to be the removed
# lines -- confirm against the repository before treating this as runnable code.
def access_chunk(self, chunk: Chunk) -> None:
# Apparent replacement signature: adds `async_access` (default False) and
# changes the return annotation from None to Optional[dist.Work].
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk can be used for calculation."""
# A chunk already in `self.accessed_chunks` is left untouched; the bare
# `return` yields None, which the Optional return annotation permits.
if chunk in self.accessed_chunks:
return
# `chunk.memory_usage` is subtracted here and added back after the access
# below -- presumably because the access changes the reported usage; verify.
self.__sub_memory_usage(chunk.memory_usage)
# A chunk whose device_type is "cpu" is shard-moved to the accelerator's
# current device before it is accessed.
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())
# Apparent pre-change call (result discarded).
self.__add_accessed_chunk(chunk)
# Apparent post-change call: forwards `async_access` and keeps the result so it
# can be returned to the caller.
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
self.__add_memory_usage(chunk.memory_usage)
# Whatever `__add_accessed_chunk` returned is handed back unchanged.
return maybe_work
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ class ChunkManager:
for k, v in usage.items():
self.total_mem[k] += v
# NOTE(review): this span comes from a diff hunk whose +/- markers, blank lines
# and indentation were lost, so two versions of `__add_accessed_chunk` are
# interleaved. Going by the hunk header (-251,10 +252,11), the signature
# directly below and the bare `chunk.access_chunk()` call appear to be the
# removed lines -- confirm against the repository.
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
# Apparent replacement: takes `async_access` (default False), forwards it to
# `chunk.access_chunk`, and keeps that call's result.
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
maybe_work = chunk.access_chunk(async_access=async_access)
# Bookkeeping: record the chunk as accessed and add its `chunk_mem` to the
# running `accessed_mem` total.
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
# The result of `chunk.access_chunk` is returned unchanged; the annotation
# allows it to be None.
return maybe_work
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()