Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[polish] use GLOBAL_MODEL_DATA_TRACER (#417)
@@ -26,15 +26,15 @@ def run_naive_amp():
     test_models = ['repeated_computed_layers', 'nested_model']
     for test_name in test_models:
         get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+        model_builder, train_dataloader, _, optim_class, _ = get_component_func()
 
         # create model
         amp_model = model_builder(checkpoint=True).cuda()
         torch_model = copy.deepcopy(amp_model)
 
         # create optimizer
-        amp_optimizer = optim_builder(amp_model)
-        torch_optimizer = optim_builder(torch_model)
+        amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
+        torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)
 
         # inject naive amp
         amp_config = dict(initial_scale=1)
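The first hunk stops treating the fourth element of the registry tuple as a ready-made optimizer builder and treats it as an optimizer class, so the test constructs both optimizers itself with explicit parameters and learning rate. Below is a minimal sketch of that calling convention; only the five-element tuple shape and the optim_class(model.parameters(), lr=...) pattern come from the diff, every name and body is a hypothetical stand-in, not the actual tests/components_to_test code.

import torch
import torch.nn as nn


def model_builder(checkpoint=False):
    # Hypothetical stand-in for a registered test model; the real builders
    # at least accept a checkpoint flag, as the call in the diff shows.
    return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))


def get_component_func():
    # Positions 3 and 5 of the tuple are not visible in the diff; placeholders here.
    train_dataloader, test_dataloader, criterion = None, None, None
    optim_class = torch.optim.Adam  # an optimizer *class*, not a pre-bound builder
    return model_builder, train_dataloader, test_dataloader, optim_class, criterion


model_builder, train_dataloader, _, optim_class, _ = get_component_func()
model = model_builder(checkpoint=True)
optimizer = optim_class(model.parameters(), lr=1e-3)  # caller supplies params and lr

Returning a class rather than a builder lets the test give the AMP model and its deep-copied torch twin independently constructed optimizers with identical hyperparameters.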
@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 def run_dist(rank, world_size, port, init_device, shard_strategy):
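The import swap above replaces the ModelDataTracer class with a ready-made module-level instance. A rough sketch of that pattern follows, under the assumption that the tracer is essentially a shared counter object; this is illustrative and not copied from colossalai.utils.memory_tracer.

class ModelDataTracer:
    """Illustrative counter for model-data memory; not the ColossalAI implementation."""

    def __init__(self):
        self._cuda_usage = 0

    def add_tensor(self, numel: int, element_size: int = 4):
        # Hypothetical helper: record the bytes of a tensor placed on CUDA.
        self._cuda_usage += numel * element_size

    @property
    def cuda_usage(self) -> int:
        return self._cuda_usage


# One shared instance created at import time; callers import this name
# instead of constructing their own tracer.
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

Call sites then read GLOBAL_MODEL_DATA_TRACER.cuda_usage, exactly as the updated test does in the next hunk.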
@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
 
-    print(f'cuda usgae {ModelDataTracer().cuda_usage}')
+    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (ModelDataTracer().cuda_usage > 0)
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
 
 
 @pytest.mark.dist
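As for why the call sites move from ModelDataTracer().cuda_usage to the shared name: unless the class enforces singleton behaviour (which the real tracer may well do), a freshly constructed instance starts with empty counters and the assertion could never observe usage recorded elsewhere; reading one explicit global also makes the shared state visible at the import. A tiny self-contained illustration with a hypothetical PlainTracer, not ColossalAI code:

from dataclasses import dataclass


@dataclass
class PlainTracer:
    cuda_usage: int = 0


GLOBAL_TRACER = PlainTracer()
GLOBAL_TRACER.cuda_usage += 1024         # some other component records usage

assert PlainTracer().cuda_usage == 0     # a fresh instance has seen nothing
assert GLOBAL_TRACER.cuda_usage == 1024  # the shared instance has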