Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-12 21:25:53 +00:00
add some todo Message
This commit is contained in:
parent 013690a86b
commit 3d625ca836
@@ -83,7 +83,7 @@ class ChunkManager:
             if chunk_group:
                 # the chunk group is not empty
                 # close the last chunk
-                self.__close_one_chunk(chunk_group[-1])
+                self.__close_one_chunk(chunk_group[-1])  # chunk[-1] is full, so close it (nothing more can be appended) and scatter it to the ZeRO PG
 
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()
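The line annotated above is the closing step of ChunkManager's tensor registration: when the open chunk cannot take another tensor, the last chunk is closed (and scattered across the ZeRO process group) and an oversized tensor bumps the chunk size. A minimal stand-alone sketch of that decision, using a hypothetical ToyChunk class rather than ColossalAI's real Chunk/ChunkManager API:

from typing import List


class ToyChunk:
    """Hypothetical stand-in for ColossalAI's Chunk: a fixed-capacity buffer of elements."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.used = 0
        self.closed = False

    def can_fit(self, numel: int) -> bool:
        return not self.closed and self.used + numel <= self.capacity

    def append(self, numel: int) -> None:
        self.used += numel


def register_tensor(chunk_group: List[ToyChunk], numel: int, chunk_size: int) -> int:
    """Mimic the annotated path: close the full chunk and grow chunk_size for oversized tensors."""
    if not chunk_group or not chunk_group[-1].can_fit(numel):
        if chunk_group:
            # the last chunk is full: close it so nothing more can be appended
            # (the real code also scatters it to the ZeRO process group at this point)
            chunk_group[-1].closed = True
        if numel > chunk_size:
            # an oversized tensor gets a chunk as large as itself
            chunk_size = numel
        chunk_group.append(ToyChunk(chunk_size))
    chunk_group[-1].append(numel)
    return chunk_size


group: List[ToyChunk] = []
size = 1024
for n in (300, 800, 4096, 100):  # the 4096-element tensor forces a close and a larger chunk size
    size = register_tensor(group, n, size)
print([(c.used, c.capacity, c.closed) for c in group], size)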
@@ -33,19 +33,22 @@ class GeminiZeROHook(ColoParamOpHook):
         all_chunks = self._chunk_manager.get_chunks(params)
 
         # wait for prefetched chunks, filter those are not prefetched
-        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)  # the chunks that need to be fetched now
 
         # transfer state
         for p in params:
+            # TODO(haze188): check the state transition
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
 
         # evit chunks, aware of async fetched
+        # TODO(haze188): a chunk we prefetched may be evicted again, check this
         self._gemini_manager.adjust_layout(
             all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
         )
 
         # fetch the rest synchronously
+        # TODO(haze188): 1. prefetch first or fetch first? (prefetch is async, fetch is sync)
         for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
 
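The first TODO asks for a check of the transition of every parameter into the COMPUTE state before the op runs. Below is a toy sketch of such a guarded transition; the states and the legal-transition table are illustrative placeholders, not ColossalAI's actual TensorState definition:

from enum import Enum, auto


class ToyState(Enum):
    HOLD = auto()
    COMPUTE = auto()
    READY_FOR_REDUCE = auto()


# only these transitions are considered legal in this toy model
LEGAL_TRANSITIONS = {
    (ToyState.HOLD, ToyState.COMPUTE),
    (ToyState.COMPUTE, ToyState.HOLD),
    (ToyState.COMPUTE, ToyState.READY_FOR_REDUCE),
}


def trans_tensor_state(current: ToyState, target: ToyState) -> ToyState:
    """Apply a transition only if it is legal, otherwise fail loudly."""
    if (current, target) not in LEGAL_TRANSITIONS:
        raise RuntimeError(f"illegal transition {current.name} -> {target.name}")
    return target


state = ToyState.HOLD
state = trans_tensor_state(state, ToyState.COMPUTE)  # parameters enter COMPUTE before the op
print(state.name)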
@@ -125,7 +125,7 @@ class GeminiManager:
                 self._async_works[chunk].wait()
                 del self._async_works[chunk]
             else:
-                non_prefetched_chunks.append(chunk)
+                non_prefetched_chunks.append(chunk)  # not prefetched earlier, so it has to be fetched now
         return tuple(non_prefetched_chunks)
 
     def add_work(self, chunk: Chunk, work: dist.Work):
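The comment added here marks the split performed by wait_chunks: chunks with an in-flight prefetch are awaited and removed from _async_works, everything else is returned so the hook can fetch it synchronously. A rough stand-alone sketch of that filter, with concurrent.futures futures standing in for torch.distributed work handles and made-up names rather than the GeminiManager API:

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List


def fetch(chunk: str) -> str:
    """Stand-in for moving a chunk onto the device."""
    return f"{chunk}:on_device"


pool = ThreadPoolExecutor(max_workers=2)
async_works: Dict[str, Future] = {"c0": pool.submit(fetch, "c0")}  # c0 was prefetched earlier


def wait_chunks(chunks: List[str]) -> List[str]:
    """Wait on in-flight prefetches; return the chunks that still need a synchronous fetch."""
    non_prefetched = []
    for chunk in chunks:
        if chunk in async_works:
            async_works.pop(chunk).result()  # plays the role of dist.Work.wait()
        else:
            non_prefetched.append(chunk)  # not prefetched earlier, must be fetched now
    return non_prefetched


needed = ["c0", "c1"]  # chunks holding the parameters of the next op
for chunk in wait_chunks(needed):  # c0 is awaited, c1 comes back for a sync fetch
    fetch(chunk)
pool.shutdown()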
@@ -154,6 +154,7 @@ class GeminiManager:
 
     def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
+        # TODO(haze188): _compute_list records the access order of the chunks
         if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)
 
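As the TODO notes, _compute_list records the order in which chunk groups are accessed during the warmup iteration, with _compute_idx tracking the current position. A minimal sketch of that bookkeeping with made-up names:

from typing import List, Tuple


class ToyOrderRecorder:
    """Hypothetical stand-in for the GeminiManager bookkeeping, not its real API."""

    def __init__(self, warmup: bool = True, need_mem_stats: bool = False):
        self._warmup = warmup
        self._need_mem_stats = need_mem_stats
        self._compute_idx = -1
        self._compute_list: List[Tuple[str, ...]] = []

    def record(self, chunks: Tuple[str, ...], record_anyway: bool = False) -> None:
        self._compute_idx += 1
        # _compute_list records the access order of the chunk groups seen during warmup
        if self._warmup and (self._need_mem_stats or record_anyway):
            self._compute_list.append(chunks)


recorder = ToyOrderRecorder()
for step_chunks in [("c0", "c1"), ("c1", "c2"), ("c3",)]:
    recorder.record(step_chunks, record_anyway=True)  # record_anyway mirrors max_prefetch > 0
print(recorder._compute_idx, recorder._compute_list)  # 2 [('c0', 'c1'), ('c1', 'c2'), ('c3',)]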
@@ -45,9 +45,9 @@ class PlacementPolicy(ABC):
         raise NotImplementedError
 
 
-import os
-
-rank = int(os.environ["RANK"])
+# import torch.distributed as dist
+# # rank = int(os.environ["RANK"])
+# rank = dist.get_rank()
 
 
 class StaticPlacementPolicy(PlacementPolicy):
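The commented-out debugging lines read the rank either from the RANK environment variable or from torch.distributed. A small guarded helper along those lines, shown only as an illustration and not part of the ColossalAI code path:

import os

import torch.distributed as dist


def current_rank() -> int:
    """Prefer the initialized process group; fall back to the RANK environment variable."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return int(os.environ.get("RANK", 0))


print(current_rank())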
@@ -118,8 +118,10 @@ class StaticPlacementPolicy(PlacementPolicy):
     def get_prefetch_chunks(self) -> List[Chunk]:
         if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
             return []
+        # the maximum number of async works we can still have in flight
         can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
         prefetch = []
+        # with static placement, if it OOMs it OOMs; dynamic placement may require analyzing the current runtime memory first and allocating space or evicting chunks
         for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
             for chunk in self.gemini_manager.compute_list[i]:
                 if len(prefetch) >= can_prefetch:
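The two new comments describe the prefetch budget and the look-ahead over the recorded compute_list. A simplified stand-alone version of that selection is sketched below; plain strings stand in for chunks, and since the hunk cuts off before the rest of the loop, the deduplication and early return here are assumptions rather than the policy's exact code:

from typing import List, Tuple


def get_prefetch_chunks(
    compute_list: List[Tuple[str, ...]],
    compute_idx: int,
    max_prefetch: int,
    num_async_works: int,
) -> List[str]:
    """Walk the recorded access order and pick up to the remaining async budget."""
    can_prefetch = max_prefetch - num_async_works  # how many more async works may be in flight
    prefetch: List[str] = []
    for i in range(compute_idx + 1, len(compute_list)):
        for chunk in compute_list[i]:
            if len(prefetch) >= can_prefetch:
                return prefetch
            if chunk not in prefetch:  # avoid queueing the same chunk twice
                prefetch.append(chunk)
    return prefetch


order = [("c0",), ("c1", "c2"), ("c3",), ("c4",)]
print(get_prefetch_chunks(order, compute_idx=0, max_prefetch=2, num_async_works=0))  # ['c1', 'c2']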
@@ -238,7 +240,9 @@ class AutoPlacementPolicy(PlacementPolicy):
             grads_device_map[p] = torch.device("cpu")
 
     def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        return []  # TODO @botbw: implement prefetching for auto
+        # TODO @haze188 @botbw: implement prefetching for auto
+
+        return []
 
 
 class PlacementPolicyFactory: