[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

This commit is contained in:
Jiarui Fang
2022-03-15 11:29:46 +08:00
committed by GitHub
parent 23ba3fc450
commit 56bb412e72
8 changed files with 25 additions and 25 deletions

View File

@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
assert param.col_attr.data.payload.device.type == init_device.type, \
f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
print(f'cuda usgae {ModelDataTracer().cuda_usage}')
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
print(f'numel {model_numel_tensor}')
if init_device.type == 'cuda':
assert (ModelDataTracer().cuda_usage > 0)
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
@pytest.mark.dist