Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] memtracer to record cuda memory usage of model data and overall system (#395)
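The test changes below exercise two pieces of this commit: a model-data tracer that counts the bytes of parameters placed on CUDA, and a memory-stats collector attached to ShardedModelV2. As a rough mental model only, here is an illustrative sketch of such a tracer; the class and method names are assumptions for illustration, not ColossalAI's actual implementation:

# Illustrative sketch only -- not ColossalAI's actual ModelDataTracer implementation.
import torch

class ModelDataTracerSketch:
    """Process-wide counter for CUDA memory held by model data (parameters, gradients)."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.cuda_usage = 0   # bytes of model data currently on CUDA
        return cls._instance

    def add_tensor(self, t: torch.Tensor) -> None:
        # Count a tensor's payload when it is created on or moved to CUDA.
        if t.is_cuda:
            self.cuda_usage += t.numel() * t.element_size()

    def delete_tensor(self, t: torch.Tensor) -> None:
        # Stop counting a tensor when it is freed or moved off CUDA.
        if t.is_cuda:
            self.cuda_usage -= t.numel() * t.element_size()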
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
    assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
    torch.cuda.empty_cache()

    # as seed manager is singleton
    # if we don't reset seeds here,
    # other tests will fail if running together with this test
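torch.cuda.empty_cache() returns the caching allocator's unused blocks to the driver, so later memory measurements and other tests do not see stale reservations. A small, self-contained illustration of the relevant PyTorch calls (not part of the commit):

# Self-contained illustration of the PyTorch calls involved; not part of this commit.
import torch

if torch.cuda.is_available():
    torch.cuda.empty_cache()                    # return unused cached blocks to the driver
    allocated = torch.cuda.memory_allocated()   # bytes held by live tensors
    reserved = torch.cuda.memory_reserved()     # bytes still reserved by the caching allocator
    print(f'allocated={allocated}, reserved={reserved}')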
@@ -9,12 +9,12 @@ import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs

from common import CONFIG
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer


def run_dist(rank, world_size, port, init_device, shard_strategy):

@@ -37,13 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
    assert param.col_attr.data.payload.device.type == init_device.type, \
        f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

    print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
    print(f'cuda usgae {ModelDataTracer().cuda_usage}')
    print(f'numel {model_numel_tensor}')
    if init_device.type == 'cuda':
        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
    elif init_device.type == 'cpu':
        assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
    assert (ModelDataTracer().cuda_usage > 0)


@pytest.mark.dist

@@ -60,5 +57,5 @@ def test_zero_init_context(world_size, init_device, shard_strategy):
if __name__ == '__main__':
    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)
    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
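For intuition, the cuda_usage assertion above amounts to checking that some model parameters actually landed on CUDA under the init context. A plain-PyTorch equivalent of that check, with an assumed helper name used purely for illustration:

# Illustrative helper (assumed name), not part of the test above.
import torch

def expected_model_data_bytes(model: torch.nn.Module) -> int:
    # Sum the payload sizes of all parameters currently resident on CUDA.
    return sum(p.numel() * p.element_size() for p in model.parameters() if p.is_cuda)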
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy


def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):

@@ -33,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
    if use_zero_init_ctx:
        with ZeroInitContext(convert_fp16=True,
                             target_device=torch.device('cpu'),
                             target_device=torch.device(f'cpu:0'),
                             shard_strategy=shard_strategy,
                             shard_param=True,
                             rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, shard_strategy)
        zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)

@@ -59,6 +60,9 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
    check_grads_padding(model, zero_model, loose=True)

    print('overall cuda ', zero_model._memstats_collector._overall_cuda)
    print('model cuda ', zero_model._memstats_collector._model_data_cuda)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
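The `_overall_cuda` and `_model_data_cuda` values printed above contrast total allocated CUDA memory with the portion attributable to model data. A minimal sketch of such a collector follows; the class and attribute names are illustrative assumptions, not ShardedModelV2's internal collector:

# Illustrative stand-in (assumed names), not ShardedModelV2's internal collector.
import torch

class MemStatsCollectorSketch:
    def __init__(self) -> None:
        self.overall_cuda = []      # total CUDA bytes allocated at each sample point
        self.model_data_cuda = []   # CUDA bytes attributed to model data at each sample point

    def sample(self, model_data_bytes: int) -> None:
        # Record the allocator-wide figure and the model-data figure together.
        self.overall_cuda.append(torch.cuda.memory_allocated())
        self.model_data_cuda.append(model_data_bytes)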
@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

@@ -82,4 +79,4 @@ def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
if __name__ == '__main__':
    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)