From a55a9e298bad86297eca89923d24d5db9b1f0aaf Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 20 May 2024 02:21:17 +0000
Subject: [PATCH 1/2] [gemini] init auto policy prefetch

---
NOTE(review): accumulate prefetch_chunk_memory when a chunk is selected for
prefetch; without the `prefetch_chunk_memory += chunk.chunk_mem` line the
counter stays 0 and the avail_cuda_model_data budget check never engages.

 colossalai/zero/gemini/placement_policy.py | 38 ++++++++++++++++++-----
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index e9e871b46..a48f8d0d0 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
+        gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -40,9 +40,8 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError
 
-    @abstractmethod
     def get_prefetch_chunks(self) -> List[Chunk]:
-        raise NotImplementedError
+        return []  # no prefetch by default
 
 
 class StaticPlacementPolicy(PlacementPolicy):
@@ -116,12 +115,14 @@ class StaticPlacementPolicy(PlacementPolicy):
         can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
         prefetch = []
         for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            break_flag = False
             for chunk in self.gemini_manager.compute_list[i]:
                 if len(prefetch) >= can_prefetch:
+                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= can_prefetch:
+            if break_flag:
                 break
         return prefetch
 
@@ -232,9 +233,32 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        return []  # TODO @botbw: implement prefetching for auto
-
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        # modified from self.evict_tensors
+        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(get_accelerator().get_current_device())
+        max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
+        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+
+        prefetch_chunk_memory = 0
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            break_flag = False
+            for chunk in self.gemini_manager.compute_list[i]:
+                chunk: Chunk
+                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
+                    break_flag = True
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch.append(chunk)
+                    prefetch_chunk_memory += chunk.chunk_mem
+            if break_flag:
+                break
+        return prefetch
 
 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {

From f1918e18a5051113290f9702a47e11266db492f2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 20 May 2024 03:00:06 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/zero/gemini/placement_policy.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index a48f8d0d0..cae5cc202 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -237,12 +237,14 @@ class AutoPlacementPolicy(PlacementPolicy):
         if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
             return []
         # modified from self.evict_tensors
-        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(get_accelerator().get_current_device())
+        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
+            get_accelerator().get_current_device()
+        )
         max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
-
+
         prefetch_chunk_memory = 0
         can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
         prefetch = []
@@ -260,6 +262,7 @@ class AutoPlacementPolicy(PlacementPolicy):
                 break
         return prefetch
 
+
 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
         "auto": AutoPlacementPolicy,