mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 13:11:27 +00:00
[zero] refactor memstats_collector (#746)
This commit is contained in:
parent
b8899e0905
commit
84c6700b2a
@ -1,3 +1,4 @@
|
|||||||
from .utils import register_ophooks_recursively, BaseOpHook
|
from .utils import register_ophooks_recursively, BaseOpHook
|
||||||
|
from ._memtracer_ophook import MemTracerOpHook
|
||||||
|
|
||||||
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
|
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
from colossalai.utils.memory import colo_device_memory_used
|
from colossalai.utils.memory import colo_device_memory_used
|
||||||
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -19,7 +19,7 @@ class MemStatsCollector:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._mem_monitor = AsyncMemoryMonitor()
|
self._mem_monitor = SyncCudaMemoryMonitor()
|
||||||
self._model_data_cuda_list = []
|
self._model_data_cuda_list = []
|
||||||
self._overall_cuda_list = []
|
self._overall_cuda_list = []
|
||||||
|
|
||||||
@ -31,9 +31,10 @@ class MemStatsCollector:
|
|||||||
self._sampling_time = []
|
self._sampling_time = []
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._period_idx = 0
|
self._step_idx = 0
|
||||||
|
self._step_total = 0
|
||||||
|
|
||||||
def overall_mem_stats(self, device_type: str):
|
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
return self._overall_cuda_list
|
return self._overall_cuda_list
|
||||||
elif device_type == 'cpu':
|
elif device_type == 'cpu':
|
||||||
@ -41,47 +42,23 @@ class MemStatsCollector:
|
|||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
def model_data_list(self, device_type: str) -> List[int]:
|
||||||
if unit == 'GB':
|
|
||||||
scale = 1e9
|
|
||||||
elif unit == 'MB':
|
|
||||||
scale = 1e6
|
|
||||||
elif unit == 'KB':
|
|
||||||
scale = 1e3
|
|
||||||
elif unit == 'B':
|
|
||||||
scale = 1
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
return [elem / scale for elem in self._model_data_cuda_list]
|
return self._model_data_cuda_list
|
||||||
elif device_type == 'cpu':
|
elif device_type == 'cpu':
|
||||||
return [elem / scale for elem in self._model_data_cpu_list]
|
return self._model_data_cpu_list
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
|
||||||
"""Non model data stats
|
|
||||||
"""
|
|
||||||
if unit == 'GB':
|
|
||||||
scale = 1e9
|
|
||||||
elif unit == 'MB':
|
|
||||||
scale = 1e6
|
|
||||||
elif unit == 'KB':
|
|
||||||
scale = 1e3
|
|
||||||
elif unit == 'B':
|
|
||||||
scale = 1
|
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
return [elem / scale for elem in self._non_model_data_cuda_list]
|
return self._non_model_data_cuda_list
|
||||||
elif device_type == 'cpu':
|
elif device_type == 'cpu':
|
||||||
return [elem / scale for elem in self._non_model_data_cpu_list]
|
return self._non_model_data_cpu_list
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def max_non_model_data(self, device_type: str) -> int:
|
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
||||||
"""Get max non model data memory usage of current sampling period
|
"""Get max non model data memory usage of current sampling period
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -91,12 +68,10 @@ class MemStatsCollector:
|
|||||||
int: max non model data memory usage of current sampling period
|
int: max non model data memory usage of current sampling period
|
||||||
"""
|
"""
|
||||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||||
assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
|
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||||
next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
|
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
|
||||||
current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
|
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||||
next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
|
return next_non_model_data
|
||||||
self._period_idx = next_period_idx
|
|
||||||
return max(current_non_model_data, next_non_model_data)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sampling_time(self):
|
def sampling_time(self):
|
||||||
@ -107,9 +82,37 @@ class MemStatsCollector:
|
|||||||
self._mem_monitor.start()
|
self._mem_monitor.start()
|
||||||
|
|
||||||
def finish_collection(self):
|
def finish_collection(self):
|
||||||
|
self.sample_overall_data()
|
||||||
|
self._step_total = len(self._sampling_time)
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._mem_monitor.finish()
|
self._mem_monitor.finish()
|
||||||
|
|
||||||
|
def sample_model_data(self) -> None:
|
||||||
|
"""Sampling model data statistics.
|
||||||
|
"""
|
||||||
|
if self._start_flag:
|
||||||
|
cuda_mem, cpu_mem = GLOBAL_MODEL_DATA_TRACER.both_mem_usage
|
||||||
|
self._model_data_cuda_list.append(cuda_mem)
|
||||||
|
self._model_data_cpu_list.append(cpu_mem)
|
||||||
|
|
||||||
|
def sample_overall_data(self) -> None:
|
||||||
|
"""Sampling non model data statistics.
|
||||||
|
"""
|
||||||
|
if self._start_flag:
|
||||||
|
# overall data recording is after model data recording
|
||||||
|
if len(self._model_data_cuda_list) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._overall_cuda_list.append(self._mem_monitor.finish())
|
||||||
|
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
|
||||||
|
|
||||||
|
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
|
||||||
|
|
||||||
|
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||||
|
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||||
|
self._sampling_time.append(time.time())
|
||||||
|
self._mem_monitor.start()
|
||||||
|
|
||||||
def sample_memstats(self) -> None:
|
def sample_memstats(self) -> None:
|
||||||
"""
|
"""
|
||||||
Sampling memory statistics.
|
Sampling memory statistics.
|
||||||
@ -119,7 +122,7 @@ class MemStatsCollector:
|
|||||||
if self._start_flag:
|
if self._start_flag:
|
||||||
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
|
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
|
||||||
self._overall_cuda_list.append(self._mem_monitor.finish())
|
self._overall_cuda_list.append(self._mem_monitor.finish())
|
||||||
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
|
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||||
|
|
||||||
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
|
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
|
||||||
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
|
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
|
||||||
@ -136,4 +139,5 @@ class MemStatsCollector:
|
|||||||
self._overall_cpu_list = []
|
self._overall_cpu_list = []
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._period_idx = 0
|
self._step_idx = 0
|
||||||
|
self._step_total = 0
|
||||||
|
@ -101,5 +101,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
|
|||||||
cuda_usage, _ = self._get_mem_usage()
|
cuda_usage, _ = self._get_mem_usage()
|
||||||
return cuda_usage
|
return cuda_usage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def both_mem_usage(self):
|
||||||
|
return self._get_mem_usage()
|
||||||
|
|
||||||
|
|
||||||
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
||||||
|
@ -109,6 +109,5 @@ class ShardedParamV2(object):
|
|||||||
|
|
||||||
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
||||||
_update_mem_use(self.param.grad)
|
_update_mem_use(self.param.grad)
|
||||||
address_set.add(self.param.grad.data_ptr())
|
|
||||||
|
|
||||||
return cuda_mem_use, cpu_mem_use
|
return cuda_mem_use, cpu_mem_use
|
||||||
|
@ -13,7 +13,7 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
|
|||||||
|
|
||||||
cuda_use, cpu_use = 0, 0
|
cuda_use, cpu_use = 0, 0
|
||||||
|
|
||||||
mem_use = t.numel() * t.element_size()
|
mem_use = t.storage().size() * t.element_size()
|
||||||
if t.device.type == 'cuda':
|
if t.device.type == 'cuda':
|
||||||
cuda_use += mem_use
|
cuda_use += mem_use
|
||||||
elif t.device.type == 'cpu':
|
elif t.device.type == 'cpu':
|
||||||
|
@ -38,10 +38,6 @@ class StatefulTensorMgr(object):
|
|||||||
def adjust_layout(self) -> None:
|
def adjust_layout(self) -> None:
|
||||||
""" Adjust the layout of statefuil tensor according to the information provided
|
""" Adjust the layout of statefuil tensor according to the information provided
|
||||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||||
|
|
||||||
Args:
|
|
||||||
mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
|
|
||||||
It contains non-model footprint of a DNN model.
|
|
||||||
"""
|
"""
|
||||||
# find stateful tensor in state COMPUTE
|
# find stateful tensor in state COMPUTE
|
||||||
cuda_demand = 0
|
cuda_demand = 0
|
||||||
|
@ -61,7 +61,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
|||||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||||
else:
|
else:
|
||||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.max_non_model_data('cuda')
|
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||||
if avail_cuda_model_data < cuda_demand:
|
if avail_cuda_model_data < cuda_demand:
|
||||||
@ -71,7 +71,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
|||||||
freed_cuda_model_data = 0
|
freed_cuda_model_data = 0
|
||||||
to_free_tensor_list = hold_cuda_tensor_list
|
to_free_tensor_list = hold_cuda_tensor_list
|
||||||
if not warmup:
|
if not warmup:
|
||||||
next_compute_idx: Dict[StatefulTensor, int] = {t: len(compute_list) for t in hold_cuda_tensor_list}
|
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensor_list}
|
||||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||||
if compute_list[i] in next_compute_idx:
|
if compute_list[i] in next_compute_idx:
|
||||||
next_compute_idx[compute_list[i]] = i
|
next_compute_idx[compute_list[i]] = i
|
||||||
|
@ -36,17 +36,7 @@ class ZeroHook(BaseOpHook):
|
|||||||
self._memstarts_collector = memstarts_collector
|
self._memstarts_collector = memstarts_collector
|
||||||
self._stateful_tensor_mgr = stateful_tensor_mgr
|
self._stateful_tensor_mgr = stateful_tensor_mgr
|
||||||
|
|
||||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
def gather_parameters(self, module: torch.nn.Module):
|
||||||
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
|
||||||
|
|
||||||
if self._stateful_tensor_mgr:
|
|
||||||
self._stateful_tensor_mgr.adjust_layout()
|
|
||||||
else:
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
|
|
||||||
|
|
||||||
# gather sharded parameters
|
# gather sharded parameters
|
||||||
if module.param_is_sharded:
|
if module.param_is_sharded:
|
||||||
tensor_list = []
|
tensor_list = []
|
||||||
@ -55,10 +45,33 @@ class ZeroHook(BaseOpHook):
|
|||||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||||
|
|
||||||
# record memory statistics
|
def shard_parameters(self, module: torch.nn.Module):
|
||||||
if self._memstarts_collector:
|
# shard gathered parameters
|
||||||
self._memstarts_collector.sample_memstats()
|
if module.param_is_sharded:
|
||||||
|
tensor_list = []
|
||||||
|
for param in module.parameters(recurse=False):
|
||||||
|
assert hasattr(param, 'colo_attr')
|
||||||
|
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||||
|
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||||
|
|
||||||
|
def adjust_module_data(self, module: torch.nn.Module):
|
||||||
|
# record overall data statistics
|
||||||
|
if self._memstarts_collector:
|
||||||
|
self._memstarts_collector.sample_overall_data()
|
||||||
|
|
||||||
|
for param in module.parameters(recurse=False):
|
||||||
|
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||||
|
|
||||||
|
# adjust stateful tensor to get enough CUDA memory
|
||||||
|
self._stateful_tensor_mgr.adjust_layout()
|
||||||
|
|
||||||
|
# record model data statistics
|
||||||
|
if self._memstarts_collector:
|
||||||
|
self._memstarts_collector.sample_model_data()
|
||||||
|
|
||||||
|
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||||
|
self.adjust_module_data(module)
|
||||||
|
self.gather_parameters(module)
|
||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
param.data = param.colo_attr.data_payload
|
param.data = param.colo_attr.data_payload
|
||||||
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
|
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
|
||||||
@ -69,41 +82,15 @@ class ZeroHook(BaseOpHook):
|
|||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
|
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
|
||||||
|
|
||||||
# shard gathered parameters
|
self.shard_parameters(module)
|
||||||
if module.param_is_sharded:
|
|
||||||
tensor_list = []
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
assert hasattr(param, 'colo_attr')
|
|
||||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
|
||||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
|
||||||
|
|
||||||
# remove torch payload
|
# remove torch payload
|
||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
param.colo_attr.set_data_none()
|
param.colo_attr.set_data_none()
|
||||||
|
|
||||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||||
|
self.adjust_module_data(module)
|
||||||
for param in module.parameters(recurse=False):
|
self.gather_parameters(module)
|
||||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
|
||||||
|
|
||||||
if self._stateful_tensor_mgr:
|
|
||||||
self._stateful_tensor_mgr.adjust_layout()
|
|
||||||
else:
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
|
|
||||||
|
|
||||||
# gather sharded parameters
|
|
||||||
if module.param_is_sharded:
|
|
||||||
tensor_list = []
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
assert hasattr(param, 'colo_attr')
|
|
||||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
|
||||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
|
||||||
|
|
||||||
# record memory statistics
|
|
||||||
if self._memstarts_collector:
|
|
||||||
self._memstarts_collector.sample_memstats()
|
|
||||||
|
|
||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
param.data = param.colo_attr.data_payload
|
param.data = param.colo_attr.data_payload
|
||||||
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
|
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
|
||||||
@ -114,13 +101,7 @@ class ZeroHook(BaseOpHook):
|
|||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||||
|
|
||||||
# shard gathered parameters
|
self.shard_parameters(module)
|
||||||
if module.param_is_sharded:
|
|
||||||
tensor_list = []
|
|
||||||
for param in module.parameters(recurse=False):
|
|
||||||
assert hasattr(param, 'colo_attr')
|
|
||||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
|
||||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
|
||||||
|
|
||||||
# remove torch payload
|
# remove torch payload
|
||||||
for param in module.parameters(recurse=False):
|
for param in module.parameters(recurse=False):
|
||||||
|
74
tests/test_zero/test_mem_collector.py
Normal file
74
tests/test_zero/test_mem_collector.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import torch
|
||||||
|
import colossalai
|
||||||
|
import pytest
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||||
|
from colossalai.zero.init_ctx import ZeroInitContext
|
||||||
|
from colossalai.zero.sharded_model import ShardedModelV2
|
||||||
|
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.testing import rerun_on_exception
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.proj1 = nn.Linear(512, 512)
|
||||||
|
self.weight = nn.Parameter(torch.randn(1024, 512))
|
||||||
|
self.proj2 = nn.Linear(1024, 512)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.proj1(x)
|
||||||
|
x = F.linear(x, self.weight)
|
||||||
|
x = self.proj2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def run_mem_collector_testing():
|
||||||
|
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||||
|
fraction = (50 * 1024**2) / cuda_capacity
|
||||||
|
# limit max memory to 50MB
|
||||||
|
colo_set_process_memory_fraction(fraction)
|
||||||
|
shard_strategy = BucketTensorShardStrategy()
|
||||||
|
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||||
|
model = TestModel()
|
||||||
|
|
||||||
|
model = ShardedModelV2(module=model,
|
||||||
|
shard_strategy=shard_strategy,
|
||||||
|
reduce_scatter_bucket_size_mb=1,
|
||||||
|
tensor_placement_policy='auto')
|
||||||
|
|
||||||
|
data = torch.randn(2, 512, device=get_current_device())
|
||||||
|
|
||||||
|
output = model(data)
|
||||||
|
loss = torch.mean(output)
|
||||||
|
model.backward(loss)
|
||||||
|
|
||||||
|
cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
|
||||||
|
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
||||||
|
|
||||||
|
cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
|
||||||
|
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
||||||
|
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
run_mem_collector_testing()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||||
|
def test_mem_collector(world_size=2):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_mem_collector()
|
@ -48,30 +48,39 @@ def run_stm():
|
|||||||
# warmup
|
# warmup
|
||||||
# use naive eviction strategy
|
# use naive eviction strategy
|
||||||
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
|
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
mem_collector.finish_collection()
|
mem_collector.finish_collection()
|
||||||
stateful_tensor_mgr.reset()
|
stateful_tensor_mgr.reset()
|
||||||
|
|
||||||
# warmup done
|
# warmup done
|
||||||
# use OPT-like eviction strategy
|
# use OPT-like eviction strategy
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
|
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.sample_overall_data()
|
||||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||||
mem_collector.sample_memstats()
|
mem_collector.sample_model_data()
|
||||||
|
mem_collector.finish_collection()
|
||||||
|
|
||||||
|
|
||||||
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
|
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
|
||||||
|
Loading…
Reference in New Issue
Block a user