[refactor] refactor the memory utils (#715)

This commit is contained in:
Jiarui Fang
2022-04-11 16:47:57 +08:00
committed by GitHub
parent dbd96fe90a
commit 193dc8dacb
20 changed files with 218 additions and 308 deletions

View File

@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Dict, List
from colossalai.utils.memory_tracer import MemStatsCollector
@@ -64,7 +64,7 @@ class StatefulTensorMgr(object):
cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
else:
raise RuntimeError
cuda_capacity = colo_cuda_memory_capacity()
cuda_capacity = colo_device_memory_capacity(get_current_device())
if self._warmup:
# We designate a part of CUDA memory for model data in warmup iterations.

View File

@@ -33,7 +33,7 @@ class TensorShardStrategy(BaseShardStrategy):
if t.is_sharded:
return
if t.payload.device.type == 'cuda':
assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
f" but current cuda device is {get_current_device()}"
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)

View File

@@ -16,7 +16,7 @@ from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -231,7 +231,7 @@ class ShardedModelV2(nn.Module):
# The margin space is calculated on the assumption that
# model data stays fixed in CUDA memory during training.
# The CUDA margin space can be used to store optimizer states (OS).
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
self._memstats_collector.overall_mem_stats('cuda'))
@torch.no_grad()