[chore] refactor & sync

hxwang
2024-05-16 07:22:10 +00:00
parent 4148ceed9f
commit 2e68eebdfe
7 changed files with 82 additions and 46 deletions
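
In the file shown below, GeminiZeROHook drops its hook-local prefetch state (_max_prefetch, _async_works, and the wait_chunks helper): waiting and work registration move behind GeminiManager (wait_chunks, add_work), max_prefetch moves to the placement policy, and rank-0 debug prints are added for non-warmup iterations.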


@@ -1,7 +1,7 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Dict, List
+from typing import Dict, List, Iterable, Tuple
 
 import torch
 import torch.distributed as dist
@@ -22,45 +22,55 @@ class GeminiZeROHook(ColoParamOpHook):
 logger = DistributedLogger("gemini_hook")
 
+import os
+
+rank = int(os.environ['RANK'])
+
 
 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
+    def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
-        self._max_prefetch = max_prefetch
-        self._async_works: Dict[Chunk, dist.Work] = {}
-
-    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
-        non_prefetched_chunks = []
-        for chunk in chunks:
-            if chunk in self._async_works:
-                print(f"prefetched {chunk.count_id}")
-                self._async_works[chunk].wait()
-                del self._async_works[chunk]
-            else:
-                non_prefetched_chunks.append(chunk)
-        return non_prefetched_chunks
 
     def pre_op(self, params):
+        # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
         # wait for prefetched chunks, filter out those not prefetched
-        chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
+        unique_chunks = set(all_chunks)
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+        # transfer state
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
-        # fetch the rest chunks synchronously
+        # evict chunks, aware of async fetched ones
+        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0)
+        # fetch the rest synchronously
         for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
-        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
+        # get possible chunks to prefetch
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+        if rank == 0 and not self._gemini_manager.is_warmup():
+            print(f"compute_idx: {self._gemini_manager.compute_idx}, compute_list: {self._gemini_manager.compute_list}")
+            print(f"{all_chunks=}")
+            print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
+            print(f"{chunks_fetch_sync=}")
+            print(f"{chunks_fetch_async=}")
+            print(f"works={list(self._gemini_manager._async_works.keys())}")
+        # prefetch
         for chunk in chunks_fetch_async:
             maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
             if maybe_work is not None:
-                self._async_works[chunk] = maybe_work
+                self._gemini_manager.add_work(chunk, maybe_work)
+        if rank == 0 and not self._gemini_manager.is_warmup():
+            print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}")
+
         # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
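
The hunk above delegates the async bookkeeping to GeminiManager via wait_chunks and add_work. Below is a minimal sketch of what those two methods plausibly look like, reconstructed from the wait_chunks logic removed from the hook; the class name, the assert, and the import path for Chunk are assumptions for illustration, not the verbatim GeminiManager code.

from typing import Dict, List

import torch.distributed as dist

from colossalai.zero.gemini.chunk import Chunk  # assumed import path


class AsyncWorkBookkeeping:
    """Sketch of the prefetch bookkeeping assumed to live on GeminiManager."""

    def __init__(self) -> None:
        # chunk -> in-flight async all-gather handle
        self._async_works: Dict[Chunk, dist.Work] = {}

    def add_work(self, chunk: Chunk, work: dist.Work) -> None:
        # remember the pending prefetch so a later op can wait on it
        assert chunk not in self._async_works, "chunk already has in-flight work"
        self._async_works[chunk] = work

    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
        # block on chunks whose prefetch is in flight and return the ones
        # that were never prefetched, so the caller fetches them synchronously
        non_prefetched = []
        for chunk in chunks:
            if chunk in self._async_works:
                self._async_works[chunk].wait()
                del self._async_works[chunk]
            else:
                non_prefetched.append(chunk)
        return non_prefetched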
@@ -88,11 +98,6 @@ class GeminiZeROHook(ColoParamOpHook):
     @contextmanager
     def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
-        if training_phase == TrainingPhase.FORWARD:
-            self._cur_param_idx = 0
-        else:
-            self._cur_param_idx = len(self._param_visited_order) - 1
         old_training_phase = self._training_phase
         try:
             self._training_phase = training_phase
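
switch_training_phase is a @contextmanager that presumably restores the old phase in a finally block (the hunk is truncated here), so callers can flip the hook's phase around a region of work. A hypothetical usage, with gemini_manager and loss standing in for real objects:

hook = GeminiZeROHook(gemini_manager)  # gemini_manager: an existing GeminiManager

# ops executed inside the block see the hook in BACKWARD phase;
# the previous phase is restored on exit, even if backward() raises
with hook.switch_training_phase(TrainingPhase.BACKWARD):
    loss.backward()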