mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[zero] non model data tracing (#545)
This commit is contained in:
@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
col_model_data_mem_usage
|
||||
colo_model_mem_usage
|
||||
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
@@ -51,7 +51,7 @@ def run_model_test(init_device_type, shard_strategy_class):
|
||||
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
|
||||
cuda_mem_use, cpu_mem_use = col_model_data_mem_usage(model)
|
||||
cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
|
||||
model_data_cuda_mem_MB = cuda_mem_use / 1e6
|
||||
logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
|
||||
sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
|
||||
|
||||
@@ -63,11 +63,13 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
|
||||
shard_param=True,
|
||||
rm_torch_payload_on_the_fly=False):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(zero_model,
|
||||
shard_strategy,
|
||||
offload_config=dict(device='cpu') if cpu_offload else None,
|
||||
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
|
||||
reuse_fp16_shard=use_cpuadam)
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
shard_strategy,
|
||||
offload_config=dict(device='cpu') if cpu_offload else None,
|
||||
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
|
||||
reuse_fp16_shard=use_cpuadam,
|
||||
)
|
||||
|
||||
model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
|
||||
Reference in New Issue
Block a user