[refactor] refactor the memory utils (#715)
@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
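
Both files touched by this commit swap the same helper: colo_cuda_memory_capacity() from colossalai.utils.memory_utils.utils, which implicitly queried the current GPU, is replaced by colo_device_memory_capacity(device) from colossalai.utils.memory, which takes the target device explicitly. A minimal sketch of what such a helper could look like; the body below is an assumption for illustration, not the library's actual implementation:

import torch

def colo_device_memory_capacity(device: torch.device) -> int:
    # Total memory capacity of `device` in bytes (sketch only; the real
    # colossalai.utils.memory version may apply a configurable fraction).
    if device.type == 'cuda':
        return torch.cuda.get_device_properties(device).total_memory
    if device.type == 'cpu':
        import psutil    # assumed dependency for the CPU branch
        return psutil.virtual_memory().total
    raise TypeError(f'unsupported device type: {device.type}')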
@@ -64,7 +64,7 @@ class StatefulTensorMgr(object):
                     cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
             else:
                 raise RuntimeError
-        cuda_capacity = colo_cuda_memory_capacity()
+        cuda_capacity = colo_device_memory_capacity(get_current_device())

         if self._warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
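
StatefulTensorMgr sums the CUDA bytes demanded by tensors entering the COMPUTE state (cuda_demand) and weighs that against the device capacity; during warmup, when no usage statistics exist yet, a fixed share of CUDA memory is designated for model data. A self-contained sketch of that budgeting decision, with assumed names (warmup_ratio, peak_non_model_data) that are illustrative rather than the class's real attributes:

def cuda_model_data_budget(capacity: int, warmup: bool, warmup_ratio: float,
                           peak_non_model_data: int) -> int:
    if warmup:
        # No statistics yet: designate a fixed fraction for model data.
        return int(capacity * warmup_ratio)
    # After warmup, keep room for the observed peak of non-model data.
    return capacity - peak_non_model_data

# E.g. on a 16 GiB device during warmup with an assumed 80% model-data share:
budget = cuda_model_data_budget(16 << 30, warmup=True, warmup_ratio=0.8,
                                peak_non_model_data=0)
# Tensors would be evicted to CPU once used model data + cuda_demand exceeds budget.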
@@ -33,7 +33,7 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.is_sharded:
             return
         if t.payload.device.type == 'cuda':
-            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+            assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
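
The assertion rewrite suggests that get_current_device() now returns a torch.device rather than an integer index: comparing t.payload.device.index (an int) against the helper only made sense while it returned a bare index; once it returns a torch.device, whole device objects must be compared. A quick PyTorch-only illustration of the difference:

import torch

dev = torch.device('cuda', 0)             # shape of a payload's .device
assert dev == torch.device('cuda:0')      # device-to-device comparison: True
assert dev.index == 0                     # index-to-int comparison: True
assert not (dev.index == torch.device('cuda:0'))  # mixed comparison is always False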
@@ -16,7 +16,7 @@ from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -231,7 +231,7 @@ class ShardedModelV2(nn.Module):
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
-            self._cuda_margin_space = colo_cuda_memory_capacity() - max(
+            self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
                 self._memstats_collector.overall_mem_stats('cuda'))

     @torch.no_grad()
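
The margin space computed above is whatever CUDA memory remains once the peak usage observed by the MemStatsCollector during forward/backward is subtracted from the device capacity; per the comment, it can later hold optimizer states (OS). A standalone sketch of the arithmetic, with fabricated sample numbers for illustration:

capacity = 16 << 30                                  # a 16 GiB device, in bytes
overall_cuda_samples = [6 << 30, 9 << 30, 7 << 30]   # assumed per-period peaks
cuda_margin_space = capacity - max(overall_cuda_samples)
assert cuda_margin_space == 7 << 30                  # 16 GiB - 9 GiB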