diff --git a/colossalai/zero/utils/tensor_placement_policy.py b/colossalai/zero/utils/tensor_placement_policy.py
index bd962cf38..e3f3dff3d 100644
--- a/colossalai/zero/utils/tensor_placement_policy.py
+++ b/colossalai/zero/utils/tensor_placement_policy.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Dict
+from typing import List, Optional
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
         to_free_tensor_list = [t for (t, idx) in next_compute_idx]
         for t in to_free_tensor_list:
-            if freed_cuda_model_data > to_free_cuda_model_data:
+            if freed_cuda_model_data >= to_free_cuda_model_data:
                 break
             freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
             colo_model_data_tensor_move_inline(t, torch.device('cpu'))
diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py
index 6edefd38f..15ca1cc5c 100644
--- a/tests/test_zero/test_stateful_tensor_mgr.py
+++ b/tests/test_zero/test_stateful_tensor_mgr.py
@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
+from colossalai.utils.memory import colo_set_process_memory_fraction
 from colossalai.zero.utils import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
@@ -21,18 +21,22 @@ class Net(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        # each parameter is 512 MB
-        self.p0 = Parameter(torch.empty(1024, 1024, 128))
-        self.p1 = Parameter(torch.empty(1024, 1024, 128))
-        self.p2 = Parameter(torch.empty(1024, 1024, 128))
+        # each parameter is 128 MB
+        self.p0 = Parameter(torch.empty(1024, 1024, 32))
+        self.p1 = Parameter(torch.empty(1024, 1024, 32))
+        self.p2 = Parameter(torch.empty(1024, 1024, 32))
+
+
+def limit_cuda_memory(memory_in_g: float):
+    cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory
+    fraction = (memory_in_g * 1024**3) / cuda_capacity
+    colo_set_process_memory_fraction(fraction)
 
 
 def run_stm():
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    fraction = (1.4 * 1024**3) / cuda_capacity
-    # limit max memory to 1.4GB
-    # which means only 2 parameters can be on CUDA
-    colo_set_process_memory_fraction(fraction)
+    # warmup phase use 20% CUDA memory to store params
+    # only 2 params can be on CUDA
+    limit_cuda_memory(1.26)
     model = Net()
     for p in model.parameters():
         p.colo_attr = ShardedParamV2(p, set_data_none=True)
@@ -65,6 +69,8 @@ def run_stm():
     stateful_tensor_mgr.reset()
 
     # warmup done
+    # only 2 params can be on CUDA
+    limit_cuda_memory(0.26)
     # use OPT-like eviction strategy
     apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
     mem_collector.sample_model_data()
@@ -112,7 +118,7 @@ def run_dist(rank, world_size, port):
     run_stm()
 
 
-@pytest.mark.skip
+@pytest.mark.gpu
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_stateful_tensor_manager(world_size=1):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
@@ -120,4 +126,5 @@ def test_stateful_tensor_manager(world_size=1):
 
 
 if __name__ == '__main__':
+    # this unit test can pass if available CUDA memory >= 1.5G
    test_stateful_tensor_manager()
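
Note on the functional change in tensor_placement_policy.py: with the old '>' comparison, the eviction loop only stopped after the freed amount strictly exceeded to_free_cuda_model_data, so when the freed bytes exactly met the target one extra tensor was still moved to CPU. The snippet below is a minimal standalone sketch of that loop (plain Python with made-up names and sizes, not the ColossalAI implementation) illustrating why '>=' frees exactly enough:

# Minimal sketch of the eviction-threshold fix; names and sizes are hypothetical, not library code.
def evict(tensor_sizes, to_free, stop_when_reached=True):
    """Return the tensors that would be moved to CPU.

    tensor_sizes: list of (name, bytes) pairs, already ordered by next compute index.
    to_free: bytes of CUDA model data that must be freed.
    stop_when_reached=True mimics the new '>=' check; False mimics the old '>'.
    """
    freed, evicted = 0, []
    for name, size in tensor_sizes:
        reached = freed >= to_free if stop_when_reached else freed > to_free
        if reached:
            break
        freed += size
        evicted.append(name)
    return evicted

sizes = [('p2', 128), ('p1', 128), ('p0', 128)]  # e.g. three 128 MB parameters
# Old '>' check: freed == 128 == to_free does not trigger the break, so a second tensor is evicted.
assert evict(sizes, to_free=128, stop_when_reached=False) == ['p2', 'p1']
# New '>=' check: the loop stops as soon as the target is met, evicting only one tensor.
assert evict(sizes, to_free=128, stop_when_reached=True) == ['p2']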