diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index d07588b08..36dae1fc0 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -6,7 +6,7 @@ import torch from colossalai.gemini.chunk import Chunk, ChunkManager -from .memory_tracer.memstats_collector import MemStatsCollectorV2 +from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic from .placement_policy import PlacementPolicyFactory @@ -26,12 +26,26 @@ class GeminiManager: chunk_manager (ChunkManager): A ``ChunkManager`` instance. """ - def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: + def __init__(self, placement_policy: str, + chunk_manager: ChunkManager, + module: Optional[torch.nn.Module] = None, + use_static_memstats: bool = False) -> None: + assert placement_policy in PlacementPolicyFactory.get_polocy_names() self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager - self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None + # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None + self.use_static_memstats = use_static_memstats + if policy_cls.need_mem_stats: + if use_static_memstats: + assert module is not None + self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager) + else: + self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) + else: + self._mem_stats_collector = None + self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -43,9 +57,13 @@ class GeminiManager: self._warmup = True self._comp_cuda_demand_time = 0 - def pre_iter(self): + def pre_iter(self, *args): if self._mem_stats_collector and self._warmup: - self._mem_stats_collector.start_collection() + if self.use_static_memstats: + self._mem_stats_collector.init_mem_stats(*args) + self._warmup = False + else: + self._mem_stats_collector.start_collection() def post_iter(self): """This function must be called when each iteration finishes diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 4366956fe..836bb716d 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -5,8 +5,16 @@ from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.gemini.chunk import ChunkManager import torch +import torch.nn as nn import time -from typing import List +from typing import List, Optional + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size) +from torch.fx import symbolic_trace + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor class MemStatsCollector: @@ -150,3 +158,101 @@ class MemStatsCollectorV2(MemStatsCollector): @property def cuda_margin_mem(self) -> float: return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda')) + + +class MemStatsCollectorStatic(MemStatsCollectorV2): + """ + A Static Memory statistic collector. + """ + + def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: + super().__init__(chunk_manager) + self.module = module + self.module_info_list = [] + + + def init_mem_stats(self, *inputs): + + self.register_opnodes_recursively(self.module) + self.refactor_module() + + self.module = self.module.cpu() + self.module.train() + + data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + gm = symbolic_trace(self.module) + interp = MetaInfoProp(gm) + interp.propagate(*data) + + total_mem = 0 + for inp in inputs: + total_mem += inp.numel() * inp.element_size() + last_node = None + module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] + for node in gm.graph.nodes: + total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem) + last_node = node + self._non_model_data_cuda_list.append(total_mem) + self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] + + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = last_node.meta["fwd_mem_out"] + for node in gm.graph.nodes.__reversed__(): + cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) + cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) + total_mem = total_mem - cur_module_mem_fwd + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = node.meta["bwd_mem_out"] + + self._step_total = len(self._non_model_data_cuda_list) + self.recover_module() + + + def refactor_module(self): + for modInfo in self.module_info_list: + temp_node = nn.Sequential(nn.ReLU(), modInfo.module) + modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) + + + def recover_module(self): + for modInfo in self.module_info_list: + modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) + + + def register_opnodes_recursively(self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None): + + assert isinstance(module, torch.nn.Module) + + for child_name, child in module.named_children(): + self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) + + +class ModuleInfos: + + def __init__(self, + module: torch.nn.Module, + module_name: str, + module_full_name: str, + parent_module: torch.nn.Module): + self.module = module + self.module_name = module_name + self.module_full_name = module_full_name + self.parent_module = parent_module \ No newline at end of file diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index d58a746b6..0fb36d8af 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP): def forward(self, *args, **kwargs): args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) self.module.zero_grad(set_to_none=True) - self.gemini_manager.pre_iter() + self.gemini_manager.pre_iter(*args) with ParamOpHookManager.use_hooks(self.param_op_hook): outputs = self.module(*args, **kwargs) if self.force_outputs_fp32: diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 7d5cfdae0..d86c31134 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -13,7 +13,7 @@ from colossalai.zero.utils import ZeroHook from colossalai.gemini.paramhooks import BaseParamHookMgr from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device, disposable -from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector +from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer @@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module): tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False, + user_static_memstats: bool = False, *args, **kwargs): assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' @@ -110,10 +111,14 @@ class ShardedModelV2(nn.Module): self.world_size = dist.get_world_size(self.process_group) self.rank = dist.get_rank(self.process_group) self.shard_strategy = shard_strategy + self.user_static_memstats = user_static_memstats self._use_memory_tracer = tensor_placement_policy == 'auto' if self._use_memory_tracer: - self._memstats_collector = MemStatsCollector() + if self.user_static_memstats: + self._memstats_collector = MemStatsCollectorStatic(self.module) + else: + self._memstats_collector = MemStatsCollector() self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) else: @@ -206,9 +211,11 @@ class ShardedModelV2(nn.Module): f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB'))) f.write('\n') - def _pre_forward_operations(self): + def _pre_forward_operations(self, *args): # the operation will affect the memory tracer behavior in ZeroHook if self._memstats_collector: + if self.user_static_memstats: + self.init_mem_stats(*args) self._start_collect_memstats() for p in self.module.parameters(): @@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: - self._pre_forward_operations() + self._pre_forward_operations(*args) args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) outputs = self.module(*args, **kwargs) self._post_forward_operations()