mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[refactor] moving memtracer to gemini (#801)
This commit is contained in:
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
||||
colo_model_mem_usage
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
|
@@ -14,7 +14,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||
from functools import partial
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
class MyTestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -37,7 +37,7 @@ def run_mem_collector_testing():
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
shard_strategy = BucketTensorShardStrategy()
|
||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||
model = TestModel()
|
||||
model = MyTestModel()
|
||||
|
||||
model = ShardedModelV2(module=model,
|
||||
shard_strategy=shard_strategy,
|
||||
|
@@ -1,122 +0,0 @@
|
||||
import torch
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_set_process_memory_fraction
|
||||
from colossalai.gemini import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
from colossalai.gemini import StatefulTensorMgr
|
||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# each parameter is 128 MB
|
||||
self.p0 = Parameter(torch.empty(1024, 1024, 32))
|
||||
self.p1 = Parameter(torch.empty(1024, 1024, 32))
|
||||
self.p2 = Parameter(torch.empty(1024, 1024, 32))
|
||||
|
||||
|
||||
def limit_cuda_memory(memory_in_g: float):
|
||||
cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory
|
||||
fraction = (memory_in_g * 1024**3) / cuda_capacity
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
|
||||
|
||||
def run_stm():
|
||||
# warmup phase use 20% CUDA memory to store params
|
||||
# only 2 params can be on CUDA
|
||||
limit_cuda_memory(1.26)
|
||||
model = Net()
|
||||
for p in model.parameters():
|
||||
p.colo_attr = ShardedParamV2(p, set_data_none=True)
|
||||
GLOBAL_MODEL_DATA_TRACER.register_model(model)
|
||||
mem_collector = MemStatsCollector()
|
||||
tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
|
||||
stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
|
||||
for p in model.parameters():
|
||||
stateful_tensor_mgr.register_stateful_param(p.colo_attr)
|
||||
|
||||
mem_collector.start_collection()
|
||||
# Compute order: 0 1 2 0 1
|
||||
# warmup
|
||||
# use naive eviction strategy
|
||||
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
|
||||
mem_collector.sample_model_data()
|
||||
mem_collector.sample_overall_data()
|
||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
mem_collector.sample_model_data()
|
||||
mem_collector.sample_overall_data()
|
||||
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_model_data()
|
||||
mem_collector.sample_overall_data()
|
||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_model_data()
|
||||
mem_collector.sample_overall_data()
|
||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_model_data()
|
||||
mem_collector.finish_collection()
|
||||
stateful_tensor_mgr.reset()
|
||||
|
||||
# warmup done
|
||||
# only 2 params can be on CUDA
|
||||
limit_cuda_memory(0.26 / tensor_placement_policy._steady_cuda_cap_ratio)
|
||||
# use OPT-like eviction strategy
|
||||
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
|
||||
|
||||
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
|
||||
stateful_tensor_mgr: StatefulTensorMgr):
|
||||
compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
for p in model.parameters():
|
||||
if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
|
||||
p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
|
||||
stateful_tensor_mgr.adjust_layout()
|
||||
print_stats(model)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
|
||||
for n, p in model.named_parameters():
|
||||
if hash(p) in cuda_param_after_adjust:
|
||||
assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
|
||||
else:
|
||||
assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')
|
||||
|
||||
|
||||
def print_stats(model: torch.nn.Module):
|
||||
msgs = []
|
||||
for n, p in model.named_parameters():
|
||||
msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
|
||||
print(f'[ {", ".join(msgs)} ]')
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_stm()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_stateful_tensor_manager(world_size=1):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# this unit test can pass if available CUDA memory >= 1.5G
|
||||
test_stateful_tensor_manager()
|
@@ -91,8 +91,6 @@ def run_dist(rank, world_size, port, parallel_config):
|
||||
|
||||
|
||||
# FIXME: enable this test in next PR
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
|
Reference in New Issue
Block a user