mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-06 18:43:58 +00:00
[Gemini] update API of the chunkmemstatscollector. (#2129)
This commit is contained in:
parent
2938edf446
commit
c89c66a858
@ -55,7 +55,7 @@ class GeminiManager:
|
|||||||
|
|
||||||
get the memory statistics during training.
|
get the memory statistics during training.
|
||||||
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
|
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
|
||||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||||
"""
|
"""
|
||||||
if self._premade_memstats_:
|
if self._premade_memstats_:
|
||||||
return self._memstats
|
return self._memstats
|
||||||
|
@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector
|
|||||||
class ChunkMemStatsCollector(MemStatsCollector):
|
class ChunkMemStatsCollector(MemStatsCollector):
|
||||||
|
|
||||||
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Memory Statistic Collector for Chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_manager (ChunkManager): the chunk manager.
|
||||||
|
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
|
||||||
|
"""
|
||||||
super().__init__(memstats)
|
super().__init__(memstats)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def record_model_data_volume(self) -> None:
|
def record_model_data_volume(self) -> None:
|
||||||
"""Sampling model data statistics.
|
"""
|
||||||
|
record model data volumn on cuda and cpu.
|
||||||
"""
|
"""
|
||||||
if self._start_flag and not self.use_outside_memstats:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
self._memstats.record_max_cuda_model_data(cuda_mem)
|
||||||
self._memstats.append_model_data('cuda', cuda_mem)
|
|
||||||
self._memstats.append_model_data('cpu', cpu_mem)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cuda_margin_mem(self) -> float:
|
def cuda_margin_mem(self) -> float:
|
||||||
|
@ -22,6 +22,7 @@ class MemStats(object):
|
|||||||
self._preop_step = 0
|
self._preop_step = 0
|
||||||
|
|
||||||
self._prev_overall_cuda = -1
|
self._prev_overall_cuda = -1
|
||||||
|
self._max_overall_cuda = 0
|
||||||
self._prev_md_cuda = -1
|
self._prev_md_cuda = -1
|
||||||
|
|
||||||
# old version
|
# old version
|
||||||
@ -46,6 +47,11 @@ class MemStats(object):
|
|||||||
|
|
||||||
def record_max_cuda_overall_data(self, val):
|
def record_max_cuda_overall_data(self, val):
|
||||||
self._prev_overall_cuda = val
|
self._prev_overall_cuda = val
|
||||||
|
self._max_overall_cuda = max(self._max_overall_cuda, val)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_overall_cuda(self):
|
||||||
|
return self._max_overall_cuda
|
||||||
|
|
||||||
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
|
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
|
||||||
"""
|
"""
|
||||||
@ -85,67 +91,6 @@ class MemStats(object):
|
|||||||
else:
|
else:
|
||||||
return self._param_runtime_order
|
return self._param_runtime_order
|
||||||
|
|
||||||
## APIs to be depracated
|
|
||||||
def append_overall_data(self, device_type: str, val: float):
|
|
||||||
if device_type == 'cuda':
|
|
||||||
self._overall_cuda_list.append(val)
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
self._overall_cpu_list.append(val)
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def append_model_data(self, device_type: str, val: float):
|
|
||||||
if device_type == 'cuda':
|
|
||||||
self._model_data_cuda_list.append(val)
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
self._model_data_cpu_list.append(val)
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def last_model_data(self, device_type: str):
|
|
||||||
if len(self._model_data_cuda_list) == 0:
|
|
||||||
return None
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._model_data_cuda_list[-1]
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._model_data_cpu_list[-1]
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def append_non_model_data(self, device_type: str, val=None):
|
|
||||||
if device_type == 'cuda':
|
|
||||||
if val is None:
|
|
||||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
|
||||||
return
|
|
||||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
|
||||||
else:
|
|
||||||
self._non_model_data_cuda_list.append(val)
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
if val is None:
|
|
||||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
|
||||||
return
|
|
||||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
|
||||||
else:
|
|
||||||
self._non_model_data_cuda_list.append(val)
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._overall_cuda_list
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._overall_cpu_list
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def model_data_list(self, device_type: str) -> List[int]:
|
|
||||||
if device_type == 'cuda':
|
|
||||||
return self._model_data_cuda_list
|
|
||||||
elif device_type == 'cpu':
|
|
||||||
return self._model_data_cpu_list
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
return self._non_model_data_cuda_list
|
return self._non_model_data_cuda_list
|
||||||
|
@ -59,6 +59,7 @@ class MemStatsCollector:
|
|||||||
return [t - self._sampling_time[0] for t in self._sampling_time]
|
return [t - self._sampling_time[0] for t in self._sampling_time]
|
||||||
|
|
||||||
def start_collection(self):
|
def start_collection(self):
|
||||||
|
print('start collection')
|
||||||
self._start_flag = True
|
self._start_flag = True
|
||||||
self._mem_monitor.start()
|
self._mem_monitor.start()
|
||||||
|
|
||||||
@ -68,31 +69,24 @@ class MemStatsCollector:
|
|||||||
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._mem_monitor.finish()
|
self._mem_monitor.finish()
|
||||||
|
print(f'finish_collection {self._step_total}')
|
||||||
|
|
||||||
|
# deprecated
|
||||||
def record_model_data_volume(self) -> None:
|
def record_model_data_volume(self) -> None:
|
||||||
"""Sampling model data statistics.
|
"""Sampling model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag and not self.use_outside_memstats:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")
|
||||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
|
||||||
self._memstats.append_model_data('cuda', cuda_mem)
|
|
||||||
self._memstats.append_model_data('cpu', cpu_mem)
|
|
||||||
|
|
||||||
def sample_overall_data(self) -> None:
|
def sample_overall_data(self) -> None:
|
||||||
"""Sampling non model data statistics.
|
"""
|
||||||
|
Sampling overall and non model data cuda memory statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag and not self.use_outside_memstats:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
# overall data recording is after model data recording
|
cuda_overall = self._mem_monitor.finish()
|
||||||
if len(self._memstats._model_data_cuda_list) == 0:
|
self._memstats.record_max_cuda_overall_data(cuda_overall)
|
||||||
return
|
self._memstats.calc_max_cuda_non_model_data()
|
||||||
|
|
||||||
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
|
|
||||||
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
|
|
||||||
|
|
||||||
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
|
|
||||||
|
|
||||||
self._memstats.append_non_model_data('cuda')
|
|
||||||
self._memstats.append_non_model_data('cpu')
|
|
||||||
self._mem_monitor.start()
|
self._mem_monitor.start()
|
||||||
|
|
||||||
if self._start_flag:
|
if self._start_flag:
|
||||||
|
@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
|
|||||||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||||
f.write('CUDA model data (GB)\n')
|
f.write('CUDA model data (GB)\n')
|
||||||
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
|
|
||||||
f.write('\n')
|
f.write('\n')
|
||||||
f.write('CUDA non model data (GB)\n')
|
f.write('CUDA non model data (GB)\n')
|
||||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||||
@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
|
|||||||
# the way to calculate margin space is based on the assumption that
|
# the way to calculate margin space is based on the assumption that
|
||||||
# model data is fixed in cuda during training.
|
# model data is fixed in cuda during training.
|
||||||
# cuda margin space can be used to store OS.
|
# cuda margin space can be used to store OS.
|
||||||
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
self._cuda_margin_space = colo_device_memory_capacity(
|
||||||
self._memstats_collector._memstats.overall_mem_stats('cuda'))
|
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _post_backward_operations(self) -> None:
|
def _post_backward_operations(self) -> None:
|
||||||
|
@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||||||
self._gemini_manager.adjust_layout(chunks)
|
self._gemini_manager.adjust_layout(chunks)
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
self._chunk_manager.access_chunk(chunk)
|
self._chunk_manager.access_chunk(chunk)
|
||||||
|
|
||||||
|
# record cuda model data of the current OP
|
||||||
self._gemini_manager.record_model_data_volume()
|
self._gemini_manager.record_model_data_volume()
|
||||||
|
|
||||||
def post_op(self, params):
|
def post_op(self, params):
|
||||||
|
@ -57,11 +57,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||||||
|
|
||||||
if model_name == 'repeated_computed_layers':
|
if model_name == 'repeated_computed_layers':
|
||||||
for idx, p in enumerate(model.parameters()):
|
for idx, p in enumerate(model.parameters()):
|
||||||
step_list = memstats.param_used_timestep(p)
|
step_list = memstats.param_used_step(p)
|
||||||
if idx < 4:
|
if idx < 4:
|
||||||
assert len(step_list) == 4
|
assert len(step_list) == 4
|
||||||
|
|
||||||
|
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||||
config_dict[world_size]['chunk_size'] = 5000
|
config_dict[world_size]['chunk_size'] = 5000
|
||||||
|
@ -1,77 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
|
||||||
from colossalai.zero.init_ctx import ZeroInitContext
|
|
||||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
|
||||||
from colossalai.zero.sharded_model import ShardedModelV2
|
|
||||||
|
|
||||||
|
|
||||||
class MyTestModel(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.proj1 = nn.Linear(512, 512)
|
|
||||||
self.weight = nn.Parameter(torch.randn(1024, 512))
|
|
||||||
self.proj2 = nn.Linear(1024, 512)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj1(x)
|
|
||||||
x = F.linear(x, self.weight)
|
|
||||||
x = self.proj2(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def run_mem_collector_testing():
|
|
||||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
|
||||||
fraction = (50 * 1024**2) / cuda_capacity
|
|
||||||
# limit max memory to 50MB
|
|
||||||
colo_set_process_memory_fraction(fraction)
|
|
||||||
shard_strategy = BucketTensorShardStrategy()
|
|
||||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
|
||||||
model = MyTestModel()
|
|
||||||
|
|
||||||
model = ShardedModelV2(module=model,
|
|
||||||
shard_strategy=shard_strategy,
|
|
||||||
reduce_scatter_bucket_size_mb=1,
|
|
||||||
tensor_placement_policy='auto')
|
|
||||||
|
|
||||||
data = torch.randn(2, 512, device=get_current_device())
|
|
||||||
|
|
||||||
output = model(data)
|
|
||||||
loss = torch.mean(output)
|
|
||||||
model.backward(loss)
|
|
||||||
|
|
||||||
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
|
|
||||||
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
|
||||||
|
|
||||||
cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
|
|
||||||
print('cuda_non_model_data_list ', cuda_non_model_data_list)
|
|
||||||
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
|
||||||
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
run_mem_collector_testing()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_mem_collector(world_size=2):
|
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_mem_collector()
|
|
Loading…
Reference in New Issue
Block a user