mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[gemini] use compute_chunk to find next chunk
This commit is contained in:
@@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
    """Return the chunks that should be prefetched next.

    Abstract hook: this base implementation only raises, so every concrete
    placement policy has to supply its own selection logic.

    Args:
        max_prefetch: Upper bound on the number of chunks to return.

    Returns:
        A list of chunks to prefetch.

    Raises:
        NotImplementedError: Always, when the base implementation is called.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
def __init__(
|
||||
@@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||
|
||||
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
    """Collect the distinct chunks listed after the current compute index.

    Walks ``chunk_manager.compute_list`` starting at the entry right after
    ``chunk_manager.compute_idx`` and gathers chunks in encounter order,
    skipping any chunk that has already been gathered.

    Args:
        max_prefetch: Upper bound on the number of chunks returned.

    Returns:
        At most ``max_prefetch`` unique chunks, in the order they first
        appear; empty when ``max_prefetch`` is not positive or when no
        entries remain after ``compute_idx``.
    """
    # NOTE(review): compute_list presumably holds, per compute step, the
    # chunks that step touches -- confirm against the chunk manager.
    manager = self.chunk_manager
    selected: List[Chunk] = []
    for step in range(manager.compute_idx + 1, len(manager.compute_list)):
        for candidate in manager.compute_list[step]:
            # Budget exhausted: stop before looking at further candidates.
            if len(selected) >= max_prefetch:
                return selected
            # The same chunk can show up in several entries; keep it once.
            if candidate in selected:
                continue
            selected.append(candidate)
        # The last append of this entry may have filled the budget.
        if len(selected) >= max_prefetch:
            return selected
    return selected
|
||||
|
||||
|
||||
class AutoPlacementPolicy(PlacementPolicy):
|
||||
need_mem_stats: bool = True
|
||||
@@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
else:
|
||||
grads_device_map[p] = torch.device("cpu")
|
||||
|
||||
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
    """Return the chunks to prefetch; currently always an empty list.

    Args:
        max_prefetch: Upper bound on the number of chunks to return
            (unused by this placeholder).

    Returns:
        An empty list.
    """
    # TODO @botbw: implement prefetching for auto
    return []
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
|
Reference in New Issue
Block a user