[zero] global model data memory tracer (#360)

This commit is contained in:
Jiarui Fang
2022-03-10 11:20:04 +08:00
committed by Frank Lee
parent cb34cd384d
commit ea2872073f
5 changed files with 94 additions and 4 deletions

View File

@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG, Net
from common import CONFIG
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
assert param.col_attr.data.is_sharded
assert param.col_attr.data.payload.device.type == 'cuda'
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
@pytest.mark.parametrize("world_size", [1, 4])
def test_zero_init_context(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)