[gemini] use compute_chunk to find next chunk

This commit is contained in:
hxwang
2024-05-16 13:17:26 +08:00
parent b2e9745888
commit 4148ceed9f
5 changed files with 52 additions and 79 deletions


@@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
        raise NotImplementedError

class StaticPlacementPolicy(PlacementPolicy):
    def __init__(
@@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
        self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
        self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
        prefetch = []
        for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
            for chunk in self.chunk_manager.compute_list[i]:
                if len(prefetch) >= max_prefetch:
                    break
                if chunk not in prefetch:
                    prefetch.append(chunk)
            if len(prefetch) >= max_prefetch:
                break
        return prefetch

class AutoPlacementPolicy(PlacementPolicy):
    need_mem_stats: bool = True
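How the new lookahead behaves, as a standalone sketch outside the diff: FakeChunkManager, lookahead_prefetch, and the chunk names are illustrative stand-ins for the real ChunkManager.compute_list / compute_idx, assuming compute_list holds groups of chunks in execution order and compute_idx marks the group currently being computed.

# Minimal sketch of the greedy lookahead used by StaticPlacementPolicy.get_prefetch_chunks.
# FakeChunkManager is a hypothetical stand-in for ChunkManager; only compute_list /
# compute_idx are modeled, and chunks are plain strings instead of Chunk objects.
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class FakeChunkManager:
    compute_list: List[Tuple[str, ...]] = field(default_factory=list)
    compute_idx: int = -1

def lookahead_prefetch(mgr: FakeChunkManager, max_prefetch: int) -> List[str]:
    # Scan the chunk groups after the current one and collect up to max_prefetch
    # distinct chunks, preserving the order in which they will be needed.
    prefetch: List[str] = []
    for i in range(mgr.compute_idx + 1, len(mgr.compute_list)):
        for chunk in mgr.compute_list[i]:
            if len(prefetch) >= max_prefetch:
                break
            if chunk not in prefetch:
                prefetch.append(chunk)
        if len(prefetch) >= max_prefetch:
            break
    return prefetch

mgr = FakeChunkManager(compute_list=[("c0",), ("c1", "c2"), ("c2", "c3"), ("c4",)], compute_idx=0)
print(lookahead_prefetch(mgr, max_prefetch=3))  # ['c1', 'c2', 'c3'] -- the duplicate 'c2' is skipped

The duplicate check keeps a chunk that appears in several upcoming compute groups from filling more than one prefetch slot.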
@@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
            else:
                grads_device_map[p] = torch.device("cpu")

    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
        return []  # TODO @botbw: implement prefetching for auto

class PlacementPolicyFactory:
    policies: Dict[str, Type[PlacementPolicy]] = {