mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[gemini] use compute_chunk to find next chunk
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
||||
from .placement_policy import PlacementPolicyFactory
|
||||
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
|
||||
|
||||
|
||||
class GeminiManager:
|
||||
@@ -91,13 +91,13 @@ class GeminiManager:
|
||||
self._warmup = False
|
||||
self.reset_attributes()
|
||||
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||
"""Adjust the layout of stateful tensors according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
start = time()
|
||||
self._record_chunks_order(chunks)
|
||||
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
||||
self._layout_time += time() - start
|
||||
|
||||
@@ -133,9 +133,9 @@ class GeminiManager:
|
||||
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
|
||||
return cuda_demand, can_evict_chunks
|
||||
|
||||
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||
self._compute_idx += 1
|
||||
if self._warmup and self._placement_policy.need_mem_stats:
|
||||
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
|
||||
self._compute_list.append(chunks)
|
||||
|
||||
def sample_overall_data(self):
|
||||
@@ -156,6 +156,18 @@ class GeminiManager:
|
||||
return self._mem_stats_collector.cuda_margin_mem
|
||||
return None
|
||||
|
||||
@property
|
||||
def compute_list(self) -> List[Tuple[Chunk, ...]]:
|
||||
return self._compute_list
|
||||
|
||||
@property
|
||||
def compute_idx(self) -> int:
|
||||
return self._compute_idx
|
||||
|
||||
@property
|
||||
def placement_policy(self) -> PlacementPolicy:
|
||||
return self._placement_policy
|
||||
|
||||
@property
|
||||
def is_cuda_margin_mem_avail(self) -> bool:
|
||||
return self._placement_policy.need_mem_stats
|
||||
|
Reference in New Issue
Block a user