[zero] refactor model data tracing (#522)

This commit is contained in:
Jiarui Fang
2022-03-25 18:03:32 +08:00
committed by GitHub
parent 3601b2bad0
commit 8d8c5407c0
8 changed files with 128 additions and 28 deletions

View File

@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
if init_device.type == 'cuda':
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
else:
assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
GLOBAL_MODEL_DATA_TRACER.clear()
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
if __name__ == '__main__':
# test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
test_zero_init_context(4)