mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 06:05:26 +00:00
[Gemini] use MemStats to store the tracing data. Seperate it from Collector. (#2084)
This commit is contained in:
parent
1f99205827
commit
33f4412102
@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
|
|
||||||
|
# override
|
||||||
def sample_model_data(self) -> None:
|
def sample_model_data(self) -> None:
|
||||||
"""Sampling model data statistics.
|
"""Sampling model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag:
|
if self._start_flag:
|
||||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||||
self._model_data_cuda_list.append(cuda_mem)
|
self._memstats.append_model_data('cuda', cuda_mem)
|
||||||
self._model_data_cpu_list.append(cpu_mem)
|
self._memstats.append_model_data('cpu', cpu_mem)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cuda_margin_mem(self) -> float:
|
def cuda_margin_mem(self) -> float:
|
||||||
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
|
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
|
||||||
|
94
colossalai/gemini/memory_tracer/memory_stats.py
Normal file
94
colossalai/gemini/memory_tracer/memory_stats.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
class MemStats(object):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Store the non model data statistics used for Gemini and ZeroOptimizer.
|
||||||
|
"""
|
||||||
|
# p -> list of non_model data volumn visied in order.
|
||||||
|
self.param_non_model_data_map: Dict(Any, List[int]) = {}
|
||||||
|
|
||||||
|
self._model_data_cuda_list = []
|
||||||
|
self._model_data_cpu_list = []
|
||||||
|
|
||||||
|
self._overall_cuda_list = []
|
||||||
|
self._overall_cpu_list = []
|
||||||
|
|
||||||
|
self._non_model_data_cuda_list = []
|
||||||
|
self._non_model_data_cpu_list = []
|
||||||
|
|
||||||
|
def append_overall_data(self, device_type: str, val: float):
|
||||||
|
if device_type == 'cuda':
|
||||||
|
self._overall_cuda_list.append(val)
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
self._overall_cpu_list.append(val)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def append_model_data(self, device_type: str, val: float):
|
||||||
|
if device_type == 'cuda':
|
||||||
|
self._model_data_cuda_list.append(val)
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
self._model_data_cpu_list.append(val)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def append_non_model_data(self, device_type: str):
|
||||||
|
if device_type == 'cuda':
|
||||||
|
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||||
|
if device_type == 'cuda':
|
||||||
|
return self._overall_cuda_list
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
return self._overall_cpu_list
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def model_data_list(self, device_type: str) -> List[int]:
|
||||||
|
if device_type == 'cuda':
|
||||||
|
return self._model_data_cuda_list
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
return self._model_data_cpu_list
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||||
|
if device_type == 'cuda':
|
||||||
|
return self._non_model_data_cuda_list
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
return self._non_model_data_cpu_list
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def max_non_model_data(self, device_type: str) -> float:
|
||||||
|
if device_type == 'cuda':
|
||||||
|
return max(self._non_model_data_cuda_list)
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
return max(self._non_model_data_cpu_list)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def max_overall_cuda(self, device_type: str) -> float:
|
||||||
|
if device_type == 'cuda':
|
||||||
|
return max(self._overall_cuda_list)
|
||||||
|
elif device_type == 'cpu':
|
||||||
|
return max(self._overall_cpu_list)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._model_data_cuda_list = []
|
||||||
|
self._overall_cuda_list = []
|
||||||
|
|
||||||
|
self._model_data_cpu_list = []
|
||||||
|
self._overall_cpu_list = []
|
||||||
|
|
||||||
|
self._non_model_data_cpu_list = []
|
||||||
|
self._non_model_data_cuda_list = []
|
@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
|||||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from colossalai.utils.memory import colo_device_memory_used
|
from colossalai.utils.memory import colo_device_memory_used
|
||||||
|
|
||||||
|
from .memory_stats import MemStats
|
||||||
|
|
||||||
|
|
||||||
class MemStatsCollector:
|
class MemStatsCollector:
|
||||||
"""
|
"""
|
||||||
@ -22,43 +24,12 @@ class MemStatsCollector:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._mem_monitor = SyncCudaMemoryMonitor()
|
self._mem_monitor = SyncCudaMemoryMonitor()
|
||||||
self._model_data_cuda_list = []
|
|
||||||
self._overall_cuda_list = []
|
|
||||||
|
|
||||||
self._model_data_cpu_list = []
|
|
||||||
self._overall_cpu_list = []
|
|
||||||
|
|
||||||
self._non_model_data_cuda_list = []
|
|
||||||
self._non_model_data_cpu_list = []
|
|
||||||
self._sampling_time = []
|
self._sampling_time = []
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._step_idx = 0
|
self._step_idx = 0
|
||||||
self._step_total = 0
|
self._step_total = 0
|
||||||
|
self._memstats = MemStats()
|
||||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._overall_cuda_list
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._overall_cpu_list
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def model_data_list(self, device_type: str) -> List[int]:
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._model_data_cuda_list
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._model_data_cpu_list
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._non_model_data_cuda_list
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._non_model_data_cpu_list
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
||||||
"""Get max non model data memory usage of current sampling period
|
"""Get max non model data memory usage of current sampling period
|
||||||
@ -71,7 +42,7 @@ class MemStatsCollector:
|
|||||||
"""
|
"""
|
||||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||||
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||||
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
|
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
|
||||||
self._step_idx = (self._step_idx + 1) % self._step_total
|
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||||
return next_non_model_data
|
return next_non_model_data
|
||||||
|
|
||||||
@ -95,37 +66,29 @@ class MemStatsCollector:
|
|||||||
if self._start_flag:
|
if self._start_flag:
|
||||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||||
self._model_data_cuda_list.append(cuda_mem)
|
self._memstats.append_model_data('cuda', cuda_mem)
|
||||||
self._model_data_cpu_list.append(cpu_mem)
|
self._memstats.append_model_data('cpu', cpu_mem)
|
||||||
|
|
||||||
def sample_overall_data(self) -> None:
|
def sample_overall_data(self) -> None:
|
||||||
"""Sampling non model data statistics.
|
"""Sampling non model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag:
|
if self._start_flag:
|
||||||
# overall data recording is after model data recording
|
# overall data recording is after model data recording
|
||||||
if len(self._model_data_cuda_list) == 0:
|
if len(self._memstats._model_data_cuda_list) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._overall_cuda_list.append(self._mem_monitor.finish())
|
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
|
||||||
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
|
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
|
||||||
|
|
||||||
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
|
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
|
||||||
|
|
||||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
self._memstats.append_non_model_data('cuda')
|
||||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
self._memstats.append_non_model_data('cpu')
|
||||||
self._sampling_time.append(time.time())
|
self._sampling_time.append(time.time())
|
||||||
self._mem_monitor.start()
|
self._mem_monitor.start()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self._model_data_cuda_list = []
|
self._memstats.clear()
|
||||||
self._overall_cuda_list = []
|
|
||||||
|
|
||||||
self._model_data_cpu_list = []
|
|
||||||
self._overall_cpu_list = []
|
|
||||||
|
|
||||||
self._non_model_data_cpu_list = []
|
|
||||||
self._non_model_data_cuda_list = []
|
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._step_idx = 0
|
self._step_idx = 0
|
||||||
self._step_total = 0
|
self._step_total = 0
|
||||||
|
@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
|
|||||||
tensor_placement_policy: str = 'cuda',
|
tensor_placement_policy: str = 'cuda',
|
||||||
gradient_predivide_factor: Optional[float] = 1.0,
|
gradient_predivide_factor: Optional[float] = 1.0,
|
||||||
reuse_fp16_shard: bool = False,
|
reuse_fp16_shard: bool = False,
|
||||||
user_static_memstats: bool = False,
|
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
||||||
@ -119,13 +118,9 @@ class ShardedModelV2(nn.Module):
|
|||||||
self.world_size = dist.get_world_size(self.process_group)
|
self.world_size = dist.get_world_size(self.process_group)
|
||||||
self.rank = dist.get_rank(self.process_group)
|
self.rank = dist.get_rank(self.process_group)
|
||||||
self.shard_strategy = shard_strategy
|
self.shard_strategy = shard_strategy
|
||||||
self.user_static_memstats = user_static_memstats
|
|
||||||
|
|
||||||
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
||||||
if self._use_memory_tracer:
|
if self._use_memory_tracer:
|
||||||
if self.user_static_memstats:
|
|
||||||
self._memstats_collector = StaticMemStatsCollector(self.module)
|
|
||||||
else:
|
|
||||||
self._memstats_collector = MemStatsCollector()
|
self._memstats_collector = MemStatsCollector()
|
||||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||||
@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
|
|||||||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||||
f.write('CUDA model data (GB)\n')
|
f.write('CUDA model data (GB)\n')
|
||||||
f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
|
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
|
||||||
f.write('\n')
|
f.write('\n')
|
||||||
f.write('CUDA non model data (GB)\n')
|
f.write('CUDA non model data (GB)\n')
|
||||||
f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
|
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||||
f.write('CPU non model data (GB)\n')
|
f.write('CPU non model data (GB)\n')
|
||||||
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
|
f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
|
||||||
f.write('\n')
|
f.write('\n')
|
||||||
|
|
||||||
def _pre_forward_operations(self, *args):
|
def _pre_forward_operations(self, *args):
|
||||||
# the operation will affect the memory tracer behavior in ZeroHook
|
# the operation will affect the memory tracer behavior in ZeroHook
|
||||||
if self._memstats_collector:
|
if self._memstats_collector:
|
||||||
if self.user_static_memstats:
|
|
||||||
self.init_mem_stats(*args)
|
|
||||||
self._start_collect_memstats()
|
self._start_collect_memstats()
|
||||||
|
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
|
|||||||
# model data is fixed in cuda during training.
|
# model data is fixed in cuda during training.
|
||||||
# cuda margin space can be used to store OS.
|
# cuda margin space can be used to store OS.
|
||||||
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
||||||
self._memstats_collector.overall_mem_stats('cuda'))
|
self._memstats_collector._memstats.overall_mem_stats('cuda'))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _post_backward_operations(self) -> None:
|
def _post_backward_operations(self) -> None:
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
import torch
|
from functools import partial
|
||||||
import colossalai
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||||
from colossalai.zero.init_ctx import ZeroInitContext
|
from colossalai.zero.init_ctx import ZeroInitContext
|
||||||
from colossalai.zero.sharded_model import ShardedModelV2
|
|
||||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||||
from colossalai.utils import free_port
|
from colossalai.zero.sharded_model import ShardedModelV2
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
class MyTestModel(torch.nn.Module):
|
class MyTestModel(torch.nn.Module):
|
||||||
@ -50,10 +52,11 @@ def run_mem_collector_testing():
|
|||||||
loss = torch.mean(output)
|
loss = torch.mean(output)
|
||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
|
|
||||||
cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
|
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
|
||||||
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
||||||
|
|
||||||
cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
|
cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
|
||||||
|
print('cuda_non_model_data_list ', cuda_non_model_data_list)
|
||||||
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
||||||
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user