diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index c7bdd5e1f..341790a72 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -83,7 +83,7 @@ class ChunkManager:
         if chunk_group:
             # the chunk group is not empty
             # close the last chunk
-            self.__close_one_chunk(chunk_group[-1])
+            self.__close_one_chunk(chunk_group[-1])  # chunk_group[-1] is full: close it so no more tensors can be appended, then scatter it across the ZeRO process group
 
         if tensor.numel() > chunk_size:
             chunk_size = tensor.numel()
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 27a19c132..bf990d127 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -33,19 +33,22 @@ class GeminiZeROHook(ColoParamOpHook):
         all_chunks = self._chunk_manager.get_chunks(params)
 
         # wait for prefetched chunks, filter those are not prefetched
-        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)  # the chunks that still have to be fetched now
 
         # transfer state
         for p in params:
+            # TODO(haze188): verify this state transition
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
 
         # evit chunks, aware of async fetched
+        # TODO(haze188): a chunk we just prefetched might be evicted again here; check this
         self._gemini_manager.adjust_layout(
             all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
         )
 
         # fetch the rest synchronously
+        # TODO(haze188): 1. should prefetch run before fetch? (prefetch is async, fetch is sync)
         for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
 
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 11bde789c..2e96c22f3 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -125,7 +125,7 @@ class GeminiManager:
                 self._async_works[chunk].wait()
                 del self._async_works[chunk]
             else:
-                non_prefetched_chunks.append(chunk)
+                non_prefetched_chunks.append(chunk)  # not prefetched earlier, so it has to be fetched now
         return tuple(non_prefetched_chunks)
 
     def add_work(self, chunk: Chunk, work: dist.Work):
@@ -154,6 +154,7 @@ class GeminiManager:
 
     def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
+        # TODO(haze188): _compute_list records the access order of the chunks
         if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)
 
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index c0f92fa50..4c3d8dbe2 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -45,9 +45,9 @@ class PlacementPolicy(ABC):
         raise NotImplementedError
 
 
-import os
-
-rank = int(os.environ["RANK"])
+# import torch.distributed as dist
+# # rank = int(os.environ["RANK"])
+# rank = dist.get_rank()
 
 
 class StaticPlacementPolicy(PlacementPolicy):
@@ -118,8 +118,10 @@ class StaticPlacementPolicy(PlacementPolicy):
     def get_prefetch_chunks(self) -> List[Chunk]:
         if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
            return []
+        # upper bound on the number of async works that may be in flight
         can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
         prefetch = []
+        # static placement simply fails on OOM; a dynamic policy would first need to analyze the current runtime memory, then allocate space or evict chunks
         for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
             for chunk in self.gemini_manager.compute_list[i]:
                 if len(prefetch) >= can_prefetch:
@@ -238,7 +240,9 @@ class AutoPlacementPolicy(PlacementPolicy):
                 grads_device_map[p] = torch.device("cpu")
 
     def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        return []  # TODO @botbw: implement prefetching for auto
+        # TODO @haze188 @botbw: implement prefetching for auto
+
+        return []
 
 
 class PlacementPolicyFactory:
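
Reviewer note on the last hunk: AutoPlacementPolicy.get_prefetch_chunks is left as a TODO. Below is a minimal sketch of one possible direction, assuming the auto policy can reuse the warmup-recorded access order (gemini_manager.compute_idx / compute_list) the same way StaticPlacementPolicy.get_prefetch_chunks does in the hunk above. chunk_fits_cuda_budget is a hypothetical helper, since an auto policy would additionally need a runtime memory check before scheduling an async fetch; this is not the repository's implementation.

from typing import List

from colossalai.zero.gemini.chunk import Chunk


def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
    """Hedged sketch for the TODO above; NOT the actual implementation."""
    if self.gemini_manager.is_warmup():  # compute_list is still being recorded
        return []
    # cap the number of in-flight async works, mirroring the static policy
    can_prefetch = max_prefetch - len(self.gemini_manager._async_works)
    prefetch: List[Chunk] = []
    for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
        for chunk in self.gemini_manager.compute_list[i]:
            if len(prefetch) >= can_prefetch:
                return prefetch
            # chunk_fits_cuda_budget is hypothetical: unlike the static policy,
            # an auto policy must check free device memory before prefetching
            if chunk not in prefetch and self.chunk_fits_cuda_budget(chunk):
                prefetch.append(chunk)
    return prefetch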