mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 04:02:17 +00:00
[hotfix] fix stm cuda model data size (#710)
This commit is contained in:
parent
140263a394
commit
715b86eadd
@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
|||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||||
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
|
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
|
||||||
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from colossalai.utils.memory_tracer import MemStatsCollector
|
from colossalai.utils.memory_tracer import MemStatsCollector
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
|
|||||||
# find stateful tensor in state COMPUTE
|
# find stateful tensor in state COMPUTE
|
||||||
move_to_cuda_tensor_list = []
|
move_to_cuda_tensor_list = []
|
||||||
cuda_demand = 0
|
cuda_demand = 0
|
||||||
used_cuda_model_data = 0
|
used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
|
||||||
hold_cuda_tensor_list = []
|
hold_cuda_tensor_list = []
|
||||||
for tensor in self._stateful_tensor_list:
|
for tensor in self._stateful_tensor_list:
|
||||||
if tensor.state == TensorState.FREE:
|
if tensor.state == TensorState.FREE:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tensor.device.type == 'cuda':
|
if tensor.device.type == 'cuda':
|
||||||
used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
|
|
||||||
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
|
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
|
||||||
hold_cuda_tensor_list.append(tensor)
|
hold_cuda_tensor_list.append(tensor)
|
||||||
elif tensor.device.type == 'cpu':
|
elif tensor.device.type == 'cpu':
|
||||||
|
Loading…
Reference in New Issue
Block a user