diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 36dae1fc0..781ffe771 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -6,7 +6,7 @@ import torch from colossalai.gemini.chunk import Chunk, ChunkManager -from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic +from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector from .placement_policy import PlacementPolicyFactory @@ -26,7 +26,8 @@ class GeminiManager: chunk_manager (ChunkManager): A ``ChunkManager`` instance. """ - def __init__(self, placement_policy: str, + def __init__(self, + placement_policy: str, chunk_manager: ChunkManager, module: Optional[torch.nn.Module] = None, use_static_memstats: bool = False) -> None: @@ -35,14 +36,14 @@ class GeminiManager: self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager - # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None + # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None self.use_static_memstats = use_static_memstats if policy_cls.need_mem_stats: if use_static_memstats: assert module is not None - self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager) + self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager) else: - self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) + self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) else: self._mem_stats_collector = None diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index 21b3e17b9..d12461353 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -1,5 +1,10 @@ -from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER -from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor -from .memstats_collector import MemStatsCollector +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .static_memstats_collector import StaticMemStatsCollector # isort:skip -__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'] +__all__ = [ + 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', + 'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER' +] diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py new file mode 100644 index 000000000..4fbc1a477 --- /dev/null +++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py @@ -0,0 +1,25 @@ +from colossalai.gemini.chunk import ChunkManager +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + +from .memstats_collector import MemStatsCollector + + +class ChunkMemStatsCollector(MemStatsCollector): + + def __init__(self, chunk_manager: ChunkManager) -> None: + super().__init__() + self._chunk_manager = chunk_manager + + def sample_model_data(self) -> None: + """Sampling model data statistics. + """ + if self._start_flag: + cuda_mem = self._chunk_manager.total_mem['cuda'] + cpu_mem = self._chunk_manager.total_mem['cpu'] + self._model_data_cuda_list.append(cuda_mem) + self._model_data_cpu_list.append(cpu_mem) + + @property + def cuda_margin_mem(self) -> float: + return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda')) diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 836bb716d..5074f3f32 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -1,26 +1,17 @@ -from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor -from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity -from colossalai.utils import get_current_device -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.chunk import ChunkManager +import time +from typing import List import torch -import torch.nn as nn -import time -from typing import List, Optional -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size) -from torch.fx import symbolic_trace - -if is_compatible_with_meta(): - from colossalai.fx.profiler import MetaTensor +from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.utils.memory import colo_device_memory_used class MemStatsCollector: """ A Memory statistic collector. - It works in two phases. + It works in two phases. Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU. The first iteration of DNN training. Phase 2. Runtime Phase: use the read-only collected stats @@ -138,121 +129,3 @@ class MemStatsCollector: self._start_flag = False self._step_idx = 0 self._step_total = 0 - - -class MemStatsCollectorV2(MemStatsCollector): - - def __init__(self, chunk_manager: ChunkManager) -> None: - super().__init__() - self._chunk_manager = chunk_manager - - def sample_model_data(self) -> None: - """Sampling model data statistics. - """ - if self._start_flag: - cuda_mem = self._chunk_manager.total_mem['cuda'] - cpu_mem = self._chunk_manager.total_mem['cpu'] - self._model_data_cuda_list.append(cuda_mem) - self._model_data_cpu_list.append(cpu_mem) - - @property - def cuda_margin_mem(self) -> float: - return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda')) - - -class MemStatsCollectorStatic(MemStatsCollectorV2): - """ - A Static Memory statistic collector. - """ - - def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: - super().__init__(chunk_manager) - self.module = module - self.module_info_list = [] - - - def init_mem_stats(self, *inputs): - - self.register_opnodes_recursively(self.module) - self.refactor_module() - - self.module = self.module.cpu() - self.module.train() - - data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] - gm = symbolic_trace(self.module) - interp = MetaInfoProp(gm) - interp.propagate(*data) - - total_mem = 0 - for inp in inputs: - total_mem += inp.numel() * inp.element_size() - last_node = None - module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] - for node in gm.graph.nodes: - total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) - if node.op == "call_module": - if node.name.endswith("_0") and node.name[:-2] in module_name_list: - self._non_model_data_cuda_list.append(total_mem) - last_node = node - self._non_model_data_cuda_list.append(total_mem) - self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] - - cur_module_mem_fwd = 0 - cur_module_mem_bwd = 0 - grad_module_out = last_node.meta["fwd_mem_out"] - for node in gm.graph.nodes.__reversed__(): - cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) - cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] - if node.op == "call_module": - if node.name.endswith("_0") and node.name[:-2] in module_name_list: - self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) - total_mem = total_mem - cur_module_mem_fwd - cur_module_mem_fwd = 0 - cur_module_mem_bwd = 0 - grad_module_out = node.meta["bwd_mem_out"] - - self._step_total = len(self._non_model_data_cuda_list) - self.recover_module() - - - def refactor_module(self): - for modInfo in self.module_info_list: - temp_node = nn.Sequential(nn.ReLU(), modInfo.module) - modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) - - - def recover_module(self): - for modInfo in self.module_info_list: - modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) - - - def register_opnodes_recursively(self, - module: torch.nn.Module, - name: str = "", - full_name: str = "", - parent_module: Optional[torch.nn.Module] = None): - - assert isinstance(module, torch.nn.Module) - - for child_name, child in module.named_children(): - self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) - - # Early return on modules with no parameters. - if len(list(module.parameters(recurse=False))) == 0: - return - - self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) - - -class ModuleInfos: - - def __init__(self, - module: torch.nn.Module, - module_name: str, - module_full_name: str, - parent_module: torch.nn.Module): - self.module = module - self.module_name = module_name - self.module_full_name = module_full_name - self.parent_module = parent_module \ No newline at end of file diff --git a/colossalai/gemini/memory_tracer/static_memstats_collector.py b/colossalai/gemini/memory_tracer/static_memstats_collector.py new file mode 100644 index 000000000..3209881e1 --- /dev/null +++ b/colossalai/gemini/memory_tracer/static_memstats_collector.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta +from colossalai.gemini.chunk import ChunkManager + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +from .chunk_memstats_collector import ChunkMemStatsCollector + + +class ModuleInfos: + + def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, + parent_module: torch.nn.Module): + self.module = module + self.module_name = module_name + self.module_full_name = module_full_name + self.parent_module = parent_module + + +class StaticMemStatsCollector(ChunkMemStatsCollector): + """ + A Static Memory statistic collector. + """ + + def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: + super().__init__(chunk_manager) + self.module = module + self.module_info_list = [] + + def init_mem_stats(self, *inputs): + + self.register_opnodes_recursively(self.module) + self.refactor_module() + + self.module = self.module.cpu() + self.module.train() + + data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + gm = symbolic_trace(self.module) + interp = MetaInfoProp(gm) + interp.propagate(*data) + + total_mem = 0 + for inp in inputs: + total_mem += inp.numel() * inp.element_size() + last_node = None + module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] + for node in gm.graph.nodes: + total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem) + last_node = node + self._non_model_data_cuda_list.append(total_mem) + self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] + + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = last_node.meta["fwd_mem_out"] + for node in gm.graph.nodes.__reversed__(): + cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) + cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) + total_mem = total_mem - cur_module_mem_fwd + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = node.meta["bwd_mem_out"] + + self._step_total = len(self._non_model_data_cuda_list) + self.recover_module() + + def refactor_module(self): + for modInfo in self.module_info_list: + temp_node = nn.Sequential(nn.ReLU(), modInfo.module) + modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) + + def recover_module(self): + for modInfo in self.module_info_list: + modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) + + def register_opnodes_recursively(self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None): + + assert isinstance(module, torch.nn.Module) + + for child_name, child in module.named_children(): + self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index ab1988b11..50004ec35 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -1,22 +1,24 @@ +import functools from abc import ABC, abstractmethod from time import time -from typing import List, Optional, Tuple, Dict +from typing import Dict, List, Optional, Tuple, Type + import torch + +from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import ChunkMemStatsCollector from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2 -from typing import Type -import functools -from colossalai.gemini.chunk import Chunk, ChunkManager - class PlacementPolicy(ABC): need_mem_stats: bool = False - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: self.chunk_manager = chunk_manager - self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector + self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @abstractmethod def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -29,7 +31,9 @@ class PlacementPolicy(ABC): class CPUPlacementPolicy(PlacementPolicy): - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy): class CUDAPlacementPolicy(PlacementPolicy): - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) @@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy): _warmup_non_model_data_ratio: float = 0.8 _steady_cuda_cap_ratio: float = 0.9 - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, @@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy): need_mem_stats: bool = False _accessed_memory_boundary = 512 * 1024**2 - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index d86c31134..bbc2b1d25 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -1,31 +1,39 @@ import functools -from collections import OrderedDict -from typing import Any, Optional, Iterator, Tuple -from copy import deepcopy import itertools +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Iterator, Optional, Tuple + import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector from colossalai.gemini.ophooks import register_ophooks_recursively -from colossalai.zero.utils import ZeroHook from colossalai.gemini.paramhooks import BaseParamHookMgr +from colossalai.gemini.stateful_tensor import TensorState +from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device, disposable -from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic +from colossalai.utils import disposable, get_current_device from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter -from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy +from colossalai.zero.utils import ZeroHook -from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, - get_gradient_predivide_factor) +from ._utils import ( + cast_float_arguments, + cast_tensor_to_fp16, + cast_tensor_to_fp32, + chunk_and_pad, + free_storage, + get_gradient_predivide_factor, +) try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX @@ -49,7 +57,7 @@ class ShardedModelV2(nn.Module): module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. - reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. + reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. @@ -60,10 +68,10 @@ class ShardedModelV2(nn.Module): Note that 'auto' policy can only work well when no other processes use CUDA during your training. Defaults to 'cuda'. gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0. - reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. - Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. - In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). - We find that PyTorch's optimizers don't support mixed precision, + reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. + Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. + In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). + We find that PyTorch's optimizers don't support mixed precision, so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. """ @@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module): self._use_memory_tracer = tensor_placement_policy == 'auto' if self._use_memory_tracer: if self.user_static_memstats: - self._memstats_collector = MemStatsCollectorStatic(self.module) + self._memstats_collector = StaticMemStatsCollector(self.module) else: self._memstats_collector = MemStatsCollector() self._start_collect_memstats = disposable(self._memstats_collector.start_collection)