[Gemini] use MemStats to store the tracing data. Separate it from Collector. (#2084)

Jiarui Fang 2022-12-06 16:43:06 +08:00 committed by GitHub
parent 1f99205827
commit 33f4412102
5 changed files with 193 additions and 139 deletions


@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
         super().__init__()
         self._chunk_manager = chunk_manager
 
+    # override
     def sample_model_data(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag:
             cuda_mem = self._chunk_manager.total_mem['cuda']
             cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     @property
     def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
+        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
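For intuition, cuda_margin_mem is unchanged in meaning: device capacity minus the peak overall CUDA usage sampled so far, i.e. the headroom Gemini can hand to the optimizer. A minimal sketch of the arithmetic, with made-up numbers (the 16 GB capacity is an assumption for illustration, not from this commit):

    capacity = 16 * 1024**3    # stand-in for colo_device_memory_capacity(get_current_device())
    overall_cuda_samples = [6 * 1024**3, 10 * 1024**3, 8 * 1024**3]    # illustrative sampled usage
    margin = capacity - max(overall_cuda_samples)    # 6 GB of headroom left for optimizer states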


@@ -0,0 +1,94 @@
+from typing import Any, Dict, List
+
+
+class MemStats(object):
+
+    def __init__(self) -> None:
+        """
+        Store the memory statistics used for Gemini and ZeroOptimizer.
+        """
+        # p -> list of non-model data volumes visited in order.
+        self.param_non_model_data_map: Dict[Any, List[int]] = {}
+
+        self._model_data_cuda_list = []
+        self._model_data_cpu_list = []
+
+        self._overall_cuda_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []
+
+    def append_overall_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._overall_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._overall_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_model_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._model_data_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._model_data_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_non_model_data(self, device_type: str):
+        # non-model data of the latest step = overall usage - model data
+        if device_type == 'cuda':
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
+        elif device_type == 'cpu':
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+        else:
+            raise TypeError
+
+    def overall_mem_stats(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._overall_cuda_list
+        elif device_type == 'cpu':
+            return self._overall_cpu_list
+        else:
+            raise TypeError
+
+    def model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._model_data_cpu_list
+        else:
+            raise TypeError
+
+    def non_model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._non_model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._non_model_data_cpu_list
+        else:
+            raise TypeError
+
+    def max_non_model_data(self, device_type: str) -> float:
+        if device_type == 'cuda':
+            return max(self._non_model_data_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._non_model_data_cpu_list)
+        else:
+            raise TypeError
+
+    def max_overall_cuda(self, device_type: str) -> float:
+        # despite the name, this dispatches on device_type like the other accessors
+        if device_type == 'cuda':
+            return max(self._overall_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._overall_cpu_list)
+        else:
+            raise TypeError
+
+    def clear(self):
+        self._model_data_cuda_list = []
+        self._overall_cuda_list = []
+
+        self._model_data_cpu_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []
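A minimal usage sketch of MemStats on its own; the byte counts are invented, and in practice the collector feeds it from SyncCudaMemoryMonitor and the stateful tensor manager:

    stats = MemStats()

    # one sampling step: model data is recorded first, then overall data
    stats.append_model_data('cuda', 1311744)
    stats.append_overall_data('cuda', 2000000)
    stats.append_non_model_data('cuda')    # derives overall - model for the latest step

    assert stats.non_model_data_list('cuda') == [2000000 - 1311744]
    stats.clear()    # note: clear() resets the lists but not param_non_model_data_map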


@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.utils.memory import colo_device_memory_used
 
+from .memory_stats import MemStats
+
 
 class MemStatsCollector:
     """
@@ -22,43 +24,12 @@ class MemStatsCollector:
 
     def __init__(self) -> None:
         self._mem_monitor = SyncCudaMemoryMonitor()
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cuda_list = []
-        self._non_model_data_cpu_list = []
         self._sampling_time = []
 
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+        self._memstats = MemStats()
 
-    def overall_mem_stats(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._overall_cuda_list
-        elif device_type == 'cpu':
-            return self._overall_cpu_list
-        else:
-            raise TypeError
-
-    def model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._model_data_cpu_list
-        else:
-            raise TypeError
-
-    def non_model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._non_model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._non_model_data_cpu_list
-        else:
-            raise TypeError
-
     def next_period_non_model_data_usage(self, device_type: str) -> int:
         """Get max non model data memory usage of current sampling period
@@ -71,7 +42,7 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
-        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
+        next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
@@ -95,37 +66,29 @@ class MemStatsCollector:
         if self._start_flag:
             cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
             cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     def sample_overall_data(self) -> None:
         """Sampling non model data statistics.
         """
         if self._start_flag:
             # overall data recording is after model data recording
-            if len(self._model_data_cuda_list) == 0:
+            if len(self._memstats._model_data_cuda_list) == 0:
                 return
 
-            self._overall_cuda_list.append(self._mem_monitor.finish())
-            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
+            self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
+            self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
 
-            assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
+            assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
 
-            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
-            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+            self._memstats.append_non_model_data('cuda')
+            self._memstats.append_non_model_data('cpu')
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
 
     def clear(self) -> None:
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cpu_list = []
-        self._non_model_data_cuda_list = []
+        self._memstats.clear()
 
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
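Putting the two classes together, a collection period now reads roughly as below. This is a hedged sketch: start_collection/finish_collection are not shown in this diff (ShardedModelV2 below calls them), and num_steps is a placeholder for however many times the sampling hooks fire in one iteration:

    collector = MemStatsCollector()

    collector.start_collection()          # assumed to set _start_flag
    for _ in range(num_steps):            # num_steps: placeholder
        collector.sample_model_data()     # appends into collector._memstats
        collector.sample_overall_data()   # appends overall data, then the derived non-model data
    collector.finish_collection()         # assumed to freeze _step_total

    # in later iterations, replay the recorded per-step non-model usage
    usage = collector.next_period_non_model_data_usage('cuda')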


@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
-                 user_static_memstats: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@@ -119,13 +118,9 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        self.user_static_memstats = user_static_memstats
 
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            if self.user_static_memstats:
-                self._memstats_collector = StaticMemStatsCollector(self.module)
-            else:
-                self._memstats_collector = MemStatsCollector()
+            self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
                 f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
                 f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
                 f.write('CUDA model data (GB)\n')
-                f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
                 f.write('\n')
                 f.write('CUDA non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
                 f.write('CPU non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
+                f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
                 f.write('\n')
 
     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
-            if self.user_static_memstats:
-                self.init_mem_stats(*args)
             self._start_collect_memstats()
 
         for p in self.module.parameters():
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector.overall_mem_stats('cuda'))
+                self._memstats_collector._memstats.overall_mem_stats('cuda'))
 
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
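With the static path gone, tensor_placement_policy='auto' always wires in the dynamic MemStatsCollector. A hedged construction sketch (the module and strategy arguments are placeholders, and the positional signature is assumed from this diff's context):

    shard_strategy = BucketTensorShardStrategy()
    sharded = ShardedModelV2(module,                          # module: any torch.nn.Module (placeholder)
                             shard_strategy,
                             tensor_placement_policy='auto')  # enables the memory tracer
    # sharded._memstats_collector is now always a MemStatsCollector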


@@ -1,17 +1,19 @@
-import torch
-import colossalai
+from functools import partial
+
 import pytest
+import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
+from colossalai.zero.sharded_model import ShardedModelV2
 
 
 class MyTestModel(torch.nn.Module):
@@ -50,10 +52,11 @@ def run_mem_collector_testing():
     loss = torch.mean(output)
     model.backward(loss)
 
-    cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
+    cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
     assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
 
-    cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
+    cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
+    print('cuda_non_model_data_list ', cuda_non_model_data_list)
     assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
     assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]