Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)
[zero] global model data memory tracer (#360)
@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from tests.components_to_test.registry import non_distributed_component_funcs
 
-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
 def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
         assert param.col_attr.data.is_sharded
         assert param.col_attr.data.payload.device.type == 'cuda'
 
+    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
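The diff only shows the tracer being read: after initializing a sharded model, the test asserts that GLOBAL_MODEL_DATA_TRACER has recorded a nonzero amount of CUDA model data. To illustrate what such a global tracer could look like, here is a minimal sketch in Python. Only the GLOBAL_MODEL_DATA_TRACER name and its cuda_usage attribute are confirmed by the diff; the ModelDataTracer class, its add_tensor/delete_tensor methods, and the byte accounting below are assumptions for illustration, not the library's actual implementation.

import torch


class ModelDataTracer:
    """Hypothetical sketch: track the total bytes of model tensors on CUDA."""

    def __init__(self) -> None:
        self._cuda_usage = 0  # running total, in bytes

    def add_tensor(self, t: torch.Tensor) -> None:
        # Count only tensors that actually reside on the GPU.
        if t.device.type == 'cuda':
            self._cuda_usage += t.numel() * t.element_size()

    def delete_tensor(self, t: torch.Tensor) -> None:
        # Mirror image of add_tensor, for tensors freed or moved off the GPU.
        if t.device.type == 'cuda':
            self._cuda_usage -= t.numel() * t.element_size()

    @property
    def cuda_usage(self) -> int:
        return self._cuda_usage


# A process-wide singleton, mirroring the import added in the diff above.
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

Under this reading, the test's assertion GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0 passes as soon as at least one sharded parameter payload has been registered on a CUDA device.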
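The last three lines of the diff use a common pattern for multi-process tests: functools.partial binds the fixed arguments, and torch.multiprocessing.spawn passes each worker its rank as the first positional argument. A standalone version of the same idiom is sketched below; the worker body and the hard-coded port are placeholders, not the test's actual logic.

from functools import partial

import torch.multiprocessing as mp


def run_dist(rank, world_size, port):
    # Placeholder worker body; the real test initializes the distributed
    # backend on this port and builds a sharded model under ZeroInitContext.
    print(f'rank {rank}/{world_size} would connect on port {port}')


def launch(world_size: int, port: int) -> None:
    # mp.spawn calls run_func(rank) for rank in range(nprocs);
    # world_size and port are pre-bound via functools.partial.
    run_func = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    launch(world_size=2, port=29500)

In the test itself, free_port() supplies an unused port per invocation, which keeps the parametrized world sizes from colliding when runs overlap.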