[refactor] refactor the memory utils (#715)

This commit is contained in:
Jiarui Fang
2022-04-11 16:47:57 +08:00
committed by GitHub
parent dbd96fe90a
commit 193dc8dacb
20 changed files with 218 additions and 308 deletions

View File

@@ -13,7 +13,7 @@ from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
colo_model_mem_usage
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -29,7 +29,7 @@ def run_model_test(init_device_type, shard_strategy_class):
for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func()
if init_device_type == 'cuda':
init_device = torch.device(f"cuda:{get_current_device()}")
init_device = get_current_device()
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:

View File

@@ -57,10 +57,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True):
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(
zero_model,

View File

@@ -2,9 +2,10 @@ import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.shard_utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
@@ -26,7 +27,7 @@ class Net(torch.nn.Module):
def run_stm():
cuda_capacity = colo_cuda_memory_capacity()
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (1.4 * 1024**3) / cuda_capacity
# limit max memory to 1.4GB
# which means only 2 parameters can be on CUDA