mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-18 16:46:08 +00:00)
[refactor] refactor the memory utils (#715)
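The refactor consolidates the memory helpers from colossalai.utils.memory_utils.utils into colossalai.utils.memory, replaces the parameterless colo_cuda_memory_capacity() with the explicit-device colo_device_memory_capacity(device), and lets call sites use the return value of get_current_device() directly instead of wrapping it in torch.device(f"cuda:{...}"). The hunks below update the affected tests.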
@@ -13,7 +13,7 @@ from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
-from colossalai.utils.memory_utils.utils import colo_device_memory_used
+from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -29,7 +29,7 @@ def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         if init_device_type == 'cuda':
-            init_device = torch.device(f"cuda:{get_current_device()}")
+            init_device = get_current_device()
         elif init_device_type == 'cpu':
            init_device = torch.device("cpu")
         else:
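The change above implies that get_current_device() now returns a torch.device rather than a bare index, so the torch.device(...) wrapping is dropped. A minimal sketch of the new pattern (the return value is inferred from this diff, not confirmed elsewhere):

    from colossalai.utils.cuda import get_current_device

    init_device = get_current_device()  # presumably torch.device('cuda', <current index>)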
@@ -57,10 +57,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

-    with ZeroInitContext(
-            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
-            shard_strategy=shard_strategy,
-            shard_param=True):
+    with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
+                         shard_strategy=shard_strategy,
+                         shard_param=True):
         zero_model = model_builder(checkpoint=True)
     zero_model = ShardedModelV2(
         zero_model,
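The same simplification applies inside ZeroInitContext, where the CUDA branch of target_device now takes get_current_device() directly. A minimal hedged sketch of the new call pattern, with a shard strategy constructed inline (build_model is a hypothetical stand-in for the test's model_builder):

    import torch
    from colossalai.utils.cuda import get_current_device
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    cpu_offload = False
    with ZeroInitContext(target_device=torch.device('cpu:0') if cpu_offload else get_current_device(),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = build_model(checkpoint=True)  # hypothetical stand-in for model_builder

The remaining hunks belong to a second test file (the diff line numbers restart), which exercises StatefulTensorMgr and the renamed memory helpers.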
@@ -2,9 +2,10 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
+from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.shard_utils import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
@@ -26,7 +27,7 @@ class Net(torch.nn.Module):


 def run_stm():
-    cuda_capacity = colo_cuda_memory_capacity()
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
     fraction = (1.4 * 1024**3) / cuda_capacity
     # limit max memory to 1.4GB
     # which means only 2 parameters can be on CUDA
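Putting the renamed helpers together: colo_device_memory_capacity(device) replaces the parameterless colo_cuda_memory_capacity(), and the computed fraction is presumably passed to colo_set_process_memory_fraction (that call sits outside the hunk). A minimal hedged sketch:

    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.memory import (colo_device_memory_capacity,
                                         colo_set_process_memory_fraction)

    # cap this process at 1.4GB of device memory, as in the test above
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    fraction = (1.4 * 1024**3) / cuda_capacity
    colo_set_process_memory_fraction(fraction)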