mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[chore] refactor & sync
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Iterable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -22,45 +22,55 @@ class TrainingPhase(Enum):
|
||||
|
||||
logger = DistributedLogger("gemini_hook")
|
||||
|
||||
import os
|
||||
rank = int(os.environ['RANK'])
|
||||
|
||||
class GeminiZeROHook(ColoParamOpHook):
|
||||
def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
self._gemini_manager = gemini_manager
|
||||
self._chunk_manager = gemini_manager.chunk_manager
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self._max_prefetch = max_prefetch
|
||||
self._async_works: Dict[Chunk, dist.work] = {}
|
||||
|
||||
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
|
||||
non_prefetched_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk in self._async_works:
|
||||
print(f"prefetched {chunk.count_id}")
|
||||
self._async_works[chunk].wait()
|
||||
del self._async_works[chunk]
|
||||
else:
|
||||
non_prefetched_chunks.append(chunk)
|
||||
return non_prefetched_chunks
|
||||
|
||||
def pre_op(self, params):
|
||||
# map params to chunks
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
all_chunks = self._chunk_manager.get_chunks(params)
|
||||
|
||||
# wait for prefetched chunks, filter those are not prefetched
|
||||
chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
|
||||
unique_chunks = set(all_chunks)
|
||||
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
|
||||
|
||||
# transfer state
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._gemini_manager.sample_overall_data()
|
||||
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
|
||||
# fetch the rest chunks synchronously
|
||||
|
||||
# evit chunks, aware of async fetched
|
||||
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0)
|
||||
|
||||
# fetch the rest synchronously
|
||||
for chunk in chunks_fetch_sync:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
|
||||
|
||||
# get possible chunks to prefetch
|
||||
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
|
||||
if rank == 0 and not self._gemini_manager.is_warmup():
|
||||
print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}")
|
||||
print(f"{all_chunks=}")
|
||||
print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
|
||||
print(f"{chunks_fetch_sync=}")
|
||||
print(f"{chunks_fetch_async=}")
|
||||
print(f"works={list(self._gemini_manager._async_works.keys())}")
|
||||
|
||||
# prefetch
|
||||
for chunk in chunks_fetch_async:
|
||||
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
|
||||
if maybe_work is not None:
|
||||
self._async_works[chunk] = maybe_work
|
||||
|
||||
self._gemini_manager.add_work(chunk, maybe_work)
|
||||
if rank == 0 and not self._gemini_manager.is_warmup():
|
||||
print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}")
|
||||
# record cuda model data of the current OP, including memory for prefetched chunks
|
||||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
@@ -88,11 +98,6 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||
|
||||
@contextmanager
|
||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||
if training_phase == TrainingPhase.FORWARD:
|
||||
self._cur_param_idx = 0
|
||||
else:
|
||||
self._cur_param_idx = len(self._param_visited_order) - 1
|
||||
|
||||
old_training_phase = self._training_phase
|
||||
try:
|
||||
self._training_phase = training_phase
|
||||
|
Reference in New Issue
Block a user