mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-13 06:11:09 +00:00
[zero] refactor model data tracing (#522)
This commit is contained in:
@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
|
||||
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
if init_device.type == 'cuda':
|
||||
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
|
||||
else:
|
||||
assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
|
||||
GLOBAL_MODEL_DATA_TRACER.clear()
|
||||
|
||||
|
||||
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
|
||||
test_zero_init_context(4)
|
||||
|
||||
Reference in New Issue
Block a user