diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py
index 63a89fcdc..877c0763c 100644
--- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py
+++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.logging import get_dist_logger
@@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
         # find stateful tensor in state COMPUTE
         move_to_cuda_tensor_list = []
         cuda_demand = 0
-        used_cuda_model_data = 0
+        used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
         hold_cuda_tensor_list = []
         for tensor in self._stateful_tensor_list:
             if tensor.state == TensorState.FREE:
                 continue
             if tensor.device.type == 'cuda':
-                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
             elif tensor.device.type == 'cpu':
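
Note on the change: instead of recomputing used_cuda_model_data by summing colo_tensor_mem_usage over every stateful tensor on each adjustment pass, the manager now reads the running counter GLOBAL_MODEL_DATA_TRACER.cuda_usage, which is the only identifier below taken from the patch itself. The following is a minimal standalone Python sketch of that idea, not ColossalAI code; ToyTracer and the helper functions are hypothetical stand-ins used purely to contrast the two accounting strategies.

    class ToyTracer:
        """Hypothetical stand-in for GLOBAL_MODEL_DATA_TRACER:
        a running byte counter updated at allocation/free time."""

        def __init__(self):
            self.cuda_usage = 0

        def add(self, nbytes: int):
            self.cuda_usage += nbytes

        def sub(self, nbytes: int):
            self.cuda_usage -= nbytes


    TRACER = ToyTracer()


    def old_accounting(cuda_tensor_sizes):
        # Before the patch: rebuild the total on every pass by
        # iterating all tensors currently resident on CUDA.
        return sum(cuda_tensor_sizes)


    def new_accounting():
        # After the patch: read the tracer's single counter instead
        # of re-summing per-tensor payload sizes.
        return TRACER.cuda_usage


    if __name__ == "__main__":
        sizes = [512, 1024]
        for n in sizes:
            TRACER.add(n)
        assert new_accounting() == old_accounting(sizes) == 1536

One consequence of this design: a centrally maintained counter can also account for CUDA model data owned outside this manager's _stateful_tensor_list, whereas the removed loop only saw tensors it iterated over.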