[Gemini] polish memstats collector (#1962)

This commit is contained in:
Jiarui Fang
2022-11-16 15:45:57 +08:00
committed by GitHub
parent fea3cb661c
commit c4739a725a
7 changed files with 201 additions and 174 deletions

View File

@@ -1,31 +1,39 @@
import functools
from collections import OrderedDict
from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy
import itertools
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Iterator, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.utils import ZeroHook
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
from colossalai.utils import disposable, get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
from colossalai.zero.utils import ZeroHook
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
from ._utils import (
cast_float_arguments,
cast_tensor_to_fp16,
cast_tensor_to_fp32,
chunk_and_pad,
free_storage,
get_gradient_predivide_factor,
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
@@ -49,7 +57,7 @@ class ShardedModelV2(nn.Module):
module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
@@ -60,10 +68,10 @@ class ShardedModelV2(nn.Module):
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
Defaults to 'cuda'.
gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
"""
@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
self._use_memory_tracer = tensor_placement_policy == 'auto'
if self._use_memory_tracer:
if self.user_static_memstats:
self._memstats_collector = MemStatsCollectorStatic(self.module)
self._memstats_collector = StaticMemStatsCollector(self.module)
else:
self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)