[gemini] use compute_chunk to find next chunk

hxwang
2024-05-16 13:17:26 +08:00
parent b2e9745888
commit 4148ceed9f
5 changed files with 52 additions and 79 deletions


@@ -6,7 +6,7 @@ import torch
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory


 class GeminiManager:
@@ -91,13 +91,13 @@ class GeminiManager:
         self._warmup = False
         self.reset_attributes()

-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belong to a Sharded Model.
         """
         # find stateful tensors in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
         cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         self._layout_time += time() - start
@@ -133,9 +133,9 @@ class GeminiManager:
         can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks

-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)

     def sample_overall_data(self):
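
With `record_anyway=True`, the warm-up pass now records the chunk compute order even when the active placement policy does not collect memory stats; previously the order was kept only for stats-driven policies. A minimal sketch of the intended call pattern, with a hypothetical hook name that is not part of this diff:

```python
# Hypothetical warm-up hook (illustrative only): force the compute order to
# be recorded so it can later be used to find the next chunk, even under a
# placement policy where need_mem_stats is False.
def pre_op(gemini_manager, chunks):
    gemini_manager.adjust_layout(chunks, record_anyway=True)
```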
@@ -156,6 +156,18 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None

+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
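
The new read-only properties expose the warm-up bookkeeping to callers outside `GeminiManager`, which is what the commit title refers to: the recorded compute order can be walked to find the next chunk. A hedged sketch of such a lookahead; the helper name and `window` parameter are assumptions, not part of this commit:

```python
from typing import List, Tuple

# Illustrative lookahead over the warm-up record (not from this diff):
# returns the chunk tuples scheduled right after the current compute step.
def find_next_chunks(gemini_manager, window: int = 1) -> List[Tuple["Chunk", ...]]:
    compute_list = gemini_manager.compute_list  # order recorded during warm-up
    next_idx = gemini_manager.compute_idx + 1   # index of the upcoming step
    return compute_list[next_idx : next_idx + window]
```

Slicing past the end of `compute_list` simply yields a shorter list, so the sketch needs no explicit bounds check.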