[Gemini] use MemStats to store the tracing data. Separate it from Collector. (#2084)

This commit is contained in:
Jiarui Fang 2022-12-06 16:43:06 +08:00 committed by GitHub
parent 1f99205827
commit 33f4412102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 193 additions and 139 deletions

View File

@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
super().__init__() super().__init__()
self._chunk_manager = chunk_manager self._chunk_manager = chunk_manager
# override
def sample_model_data(self) -> None: def sample_model_data(self) -> None:
"""Sampling model data statistics. """Sampling model data statistics.
""" """
if self._start_flag: if self._start_flag:
cuda_mem = self._chunk_manager.total_mem['cuda'] cuda_mem = self._chunk_manager.total_mem['cuda']
cpu_mem = self._chunk_manager.total_mem['cpu'] cpu_mem = self._chunk_manager.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem) self._memstats.append_model_data('cuda', cuda_mem)
self._model_data_cpu_list.append(cpu_mem) self._memstats.append_model_data('cpu', cpu_mem)
@property @property
def cuda_margin_mem(self) -> float: def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda')) return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')

View File

@ -0,0 +1,94 @@
from typing import Any, Dict, List
class MemStats(object):

    def __init__(self) -> None:
        """
        Store the non model data statistics used for Gemini and ZeroOptimizer.
        """
        # param -> list of non-model data volumes visited in order.
        # FIX: was `Dict(Any, List[int])`, which *calls* typing.Dict when the
        # annotation of a complex target (self.x) is evaluated and raises
        # TypeError; subscription syntax is the correct generic form.
        self.param_non_model_data_map: Dict[Any, List[int]] = {}

        # Sampling histories per device; presumably memory volumes in bytes —
        # TODO(review): confirm unit against the collector that feeds these.
        # model/overall lists are appended in lock-step by the collector.
        self._model_data_cuda_list = []
        self._model_data_cpu_list = []

        self._overall_cuda_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []

    def append_overall_data(self, device_type: str, val: float):
        """Record one overall (model + non-model) memory sample for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            self._overall_cuda_list.append(val)
        elif device_type == 'cpu':
            self._overall_cpu_list.append(val)
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def append_model_data(self, device_type: str, val: float):
        """Record one model data memory sample for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            self._model_data_cuda_list.append(val)
        elif device_type == 'cpu':
            self._model_data_cpu_list.append(val)
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def append_non_model_data(self, device_type: str):
        """Derive and record the latest non-model data sample for a device.

        Computed as the difference between the most recent overall sample and
        the most recent model data sample, so both must already be appended.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
        elif device_type == 'cpu':
            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def overall_mem_stats(self, device_type: str) -> List[int]:
        """Return the list of overall memory samples for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            return self._overall_cuda_list
        elif device_type == 'cpu':
            return self._overall_cpu_list
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def model_data_list(self, device_type: str) -> List[int]:
        """Return the list of model data memory samples for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            return self._model_data_cuda_list
        elif device_type == 'cpu':
            return self._model_data_cpu_list
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def non_model_data_list(self, device_type: str) -> List[int]:
        """Return the list of non-model data memory samples for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        if device_type == 'cuda':
            return self._non_model_data_cuda_list
        elif device_type == 'cpu':
            return self._non_model_data_cpu_list
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def max_non_model_data(self, device_type: str) -> float:
        """Return the peak non-model data memory sample for a device.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
            ValueError: if no sample has been recorded yet (empty max()).
        """
        if device_type == 'cuda':
            return max(self._non_model_data_cuda_list)
        elif device_type == 'cpu':
            return max(self._non_model_data_cpu_list)
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def max_overall_cuda(self, device_type: str) -> float:
        """Return the peak overall memory sample for a device.

        NOTE(review): despite the '_cuda' suffix this method serves both
        devices; name kept for caller compatibility (see cuda_margin_mem).

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
            ValueError: if no sample has been recorded yet (empty max()).
        """
        if device_type == 'cuda':
            return max(self._overall_cuda_list)
        elif device_type == 'cpu':
            return max(self._overall_cpu_list)
        else:
            raise TypeError(f"unknown device type {device_type!r}, expected 'cuda' or 'cpu'")

    def clear(self):
        """Reset every sampling history to empty for a new collection phase."""
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cpu_list = []
        self._non_model_data_cuda_list = []

View File

@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used from colossalai.utils.memory import colo_device_memory_used
from .memory_stats import MemStats
class MemStatsCollector: class MemStatsCollector:
""" """
@ -22,43 +24,12 @@ class MemStatsCollector:
def __init__(self) -> None: def __init__(self) -> None:
self._mem_monitor = SyncCudaMemoryMonitor() self._mem_monitor = SyncCudaMemoryMonitor()
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._sampling_time = [] self._sampling_time = []
self._start_flag = False self._start_flag = False
self._step_idx = 0 self._step_idx = 0
self._step_total = 0 self._step_total = 0
self._memstats = MemStats()
def overall_mem_stats(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
return self._overall_cpu_list
else:
raise TypeError
def model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._model_data_cuda_list
elif device_type == 'cpu':
return self._model_data_cpu_list
else:
raise TypeError
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list
elif device_type == 'cpu':
return self._non_model_data_cpu_list
else:
raise TypeError
def next_period_non_model_data_usage(self, device_type: str) -> int: def next_period_non_model_data_usage(self, device_type: str) -> int:
"""Get max non model data memory usage of current sampling period """Get max non model data memory usage of current sampling period
@ -71,7 +42,7 @@ class MemStatsCollector:
""" """
assert not self._start_flag, 'Cannot get mem stats info during collection phase.' assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx] next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data return next_non_model_data
@ -95,37 +66,29 @@ class MemStatsCollector:
if self._start_flag: if self._start_flag:
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem) self._memstats.append_model_data('cuda', cuda_mem)
self._model_data_cpu_list.append(cpu_mem) self._memstats.append_model_data('cpu', cpu_mem)
def sample_overall_data(self) -> None: def sample_overall_data(self) -> None:
"""Sampling non model data statistics. """Sampling non model data statistics.
""" """
if self._start_flag: if self._start_flag:
# overall data recording is after model data recording # overall data recording is after model data recording
if len(self._model_data_cuda_list) == 0: if len(self._memstats._model_data_cuda_list) == 0:
return return
self._overall_cuda_list.append(self._mem_monitor.finish()) self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu'))) self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list) assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) self._memstats.append_non_model_data('cuda')
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) self._memstats.append_non_model_data('cpu')
self._sampling_time.append(time.time()) self._sampling_time.append(time.time())
self._mem_monitor.start() self._mem_monitor.start()
def clear(self) -> None: def clear(self) -> None:
self._model_data_cuda_list = [] self._memstats.clear()
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cpu_list = []
self._non_model_data_cuda_list = []
self._start_flag = False self._start_flag = False
self._step_idx = 0 self._step_idx = 0
self._step_total = 0 self._step_total = 0

View File

@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda', tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0, gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False, reuse_fp16_shard: bool = False,
user_static_memstats: bool = False,
*args, *args,
**kwargs): **kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
self.world_size = dist.get_world_size(self.process_group) self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group) self.rank = dist.get_rank(self.process_group)
self.shard_strategy = shard_strategy self.shard_strategy = shard_strategy
self.user_static_memstats = user_static_memstats
self._use_memory_tracer = tensor_placement_policy == 'auto' self._use_memory_tracer = tensor_placement_policy == 'auto'
if self._use_memory_tracer: if self._use_memory_tracer:
if self.user_static_memstats: self._memstats_collector = MemStatsCollector()
self._memstats_collector = StaticMemStatsCollector(self.module)
else:
self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
else: else:
@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n') f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector.model_data_list('cuda', 'GB'))) f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
f.write('\n') f.write('\n')
f.write('CUDA non model data (GB)\n') f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB'))) f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
f.write('CPU non model data (GB)\n') f.write('CPU non model data (GB)\n')
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB'))) f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
f.write('\n') f.write('\n')
def _pre_forward_operations(self, *args): def _pre_forward_operations(self, *args):
# the operation will affect the memory tracer behavior in ZeroHook # the operation will affect the memory tracer behavior in ZeroHook
if self._memstats_collector: if self._memstats_collector:
if self.user_static_memstats:
self.init_mem_stats(*args)
self._start_collect_memstats() self._start_collect_memstats()
for p in self.module.parameters(): for p in self.module.parameters():
@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training. # model data is fixed in cuda during training.
# cuda margin space can be used to store OS. # cuda margin space can be used to store OS.
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max( self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
self._memstats_collector.overall_mem_stats('cuda')) self._memstats_collector._memstats.overall_mem_stats('cuda'))
@torch.no_grad() @torch.no_grad()
def _post_backward_operations(self) -> None: def _post_backward_operations(self) -> None:

View File

@ -1,74 +1,77 @@
import torch from functools import partial
import colossalai
import pytest import pytest
import torch.multiprocessing as mp import torch
import torch.nn as nn import torch.multiprocessing as mp
import torch.nn.functional as F import torch.nn as nn
from colossalai.utils.cuda import get_current_device import torch.nn.functional as F
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext import colossalai
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.shard_utils import BucketTensorShardStrategy from colossalai.utils import free_port
from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device
from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from functools import partial from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
class MyTestModel(torch.nn.Module):
def __init__(self) -> None: class MyTestModel(torch.nn.Module):
super().__init__()
self.proj1 = nn.Linear(512, 512) def __init__(self) -> None:
self.weight = nn.Parameter(torch.randn(1024, 512)) super().__init__()
self.proj2 = nn.Linear(1024, 512) self.proj1 = nn.Linear(512, 512)
self.weight = nn.Parameter(torch.randn(1024, 512))
def forward(self, x): self.proj2 = nn.Linear(1024, 512)
x = self.proj1(x)
x = F.linear(x, self.weight) def forward(self, x):
x = self.proj2(x) x = self.proj1(x)
x = F.linear(x, self.weight)
return x x = self.proj2(x)
return x
def run_mem_collector_testing():
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (50 * 1024**2) / cuda_capacity def run_mem_collector_testing():
# limit max memory to 50MB cuda_capacity = colo_device_memory_capacity(get_current_device())
colo_set_process_memory_fraction(fraction) fraction = (50 * 1024**2) / cuda_capacity
shard_strategy = BucketTensorShardStrategy() # limit max memory to 50MB
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True): colo_set_process_memory_fraction(fraction)
model = MyTestModel() shard_strategy = BucketTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
model = ShardedModelV2(module=model, model = MyTestModel()
shard_strategy=shard_strategy,
reduce_scatter_bucket_size_mb=1, model = ShardedModelV2(module=model,
tensor_placement_policy='auto') shard_strategy=shard_strategy,
reduce_scatter_bucket_size_mb=1,
data = torch.randn(2, 512, device=get_current_device()) tensor_placement_policy='auto')
output = model(data) data = torch.randn(2, 512, device=get_current_device())
loss = torch.mean(output)
model.backward(loss) output = model(data)
loss = torch.mean(output)
cuda_model_data_list = model._memstats_collector.model_data_list('cuda') model.backward(loss)
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda') assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1] cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
print('cuda_non_model_data_list ', cuda_non_model_data_list)
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
def run_dist(rank, world_size, port): assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_mem_collector_testing()
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@pytest.mark.dist run_mem_collector_testing()
@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port()) @pytest.mark.dist
mp.spawn(run_func, nprocs=world_size) @rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port())
if __name__ == '__main__': mp.spawn(run_func, nprocs=world_size)
test_mem_collector()
if __name__ == '__main__':
test_mem_collector()