[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2024-05-16 07:26:19 +00:00
parent 82b25524ff
commit 6bbe956316
4 changed files with 32 additions and 17 deletions

View File

@@ -1,10 +1,9 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import Dict, List, Iterable, Tuple
from typing import List
import torch
import torch.distributed as dist
from colossalai.logging import DistributedLogger
from colossalai.tensor.param_op_hook import ColoParamOpHook
@@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from .chunk import Chunk
class TrainingPhase(Enum):
FORWARD = 0
@@ -23,7 +20,9 @@ class TrainingPhase(Enum):
logger = DistributedLogger("gemini_hook")
import os
rank = int(os.environ['RANK'])
rank = int(os.environ["RANK"])
class GeminiZeROHook(ColoParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
@@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook):
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
# map params to chunks
params = [p for p in params if not is_ddp_ignored(p)]
all_chunks = self._chunk_manager.get_chunks(params)
# wait for prefetched chunks, filter those are not prefetched
unique_chunks = set(all_chunks)
set(all_chunks)
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
# transfer state
@@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook):
self._gemini_manager.sample_overall_data()
# evit chunks, aware of async fetched
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0)
self._gemini_manager.adjust_layout(
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
)
# fetch the rest synchronously
for chunk in chunks_fetch_sync:
@@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook):
# get possible chunks to prefetch
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
if rank == 0 and not self._gemini_manager.is_warmup():
print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}")
print(
f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
)
print(f"{all_chunks=}")
print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
print(f"{chunks_fetch_sync=}")