[gemini] add GeminiMemoryManger (#832)
* refactor StatefulTensor and tensor utilities
* add a unit test for GeminiMemoryManager
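For readers new to this component, the sketch below illustrates the kind of bookkeeping the new test exercises: a ledger of how many bytes live on each device, broken down by tensor state. All names in the sketch (ToyTensorState, ToyMemLedger, register, transit) are hypothetical illustrations, not ColossalAI's actual GeminiMemoryManager API; the real interface is the one the test below asserts against (total_number, total_mem, state_mem).

# Hypothetical, minimal sketch of a per-device, per-state byte ledger.
# It is not ColossalAI's implementation; it only mirrors the bookkeeping
# that the new test checks.
from collections import defaultdict
from enum import Enum

import torch


class ToyTensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_FWD = 3
    HOLD_AFTER_BWD = 4


class ToyMemLedger:
    """Tracks bytes per device, broken down by tensor state."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total_number = 0
        self.total_mem = {'cpu': 0, 'cuda': 0}
        self.state_mem = {'cpu': defaultdict(int), 'cuda': defaultdict(int)}

    def register(self, tensor, state):
        # Every tensor counts toward total_number, even FREE ones with no payload.
        self.total_number += 1
        if tensor is not None and state != ToyTensorState.FREE:
            dev = tensor.device.type
            nbytes = tensor.numel() * tensor.element_size()
            self.total_mem[dev] += nbytes
            self.state_mem[dev][state] += nbytes

    def transit(self, tensor, old, new):
        # A state transition moves bytes between buckets on the same device.
        dev = tensor.device.type
        nbytes = tensor.numel() * tensor.element_size()
        self.state_mem[dev][old] -= nbytes
        self.state_mem[dev][new] += nbytes


ledger = ToyMemLedger()
t = torch.empty(3, 5, dtype=torch.float32)    # 3 * 5 * 4 = 60 bytes on CPU
ledger.register(t, ToyTensorState.HOLD)
assert ledger.total_mem['cpu'] == 60
ledger.transit(t, ToyTensorState.HOLD, ToyTensorState.COMPUTE)
assert ledger.state_mem['cpu'][ToyTensorState.COMPUTE] == 60

The real manager must also keep its counters consistent across payload swaps and cross-device moves, which the new test exercises through payload_reset and move_to.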
tests/test_gemini/test_gemini_manager.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import pytest
import torch

from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor


@pytest.mark.dist
def test_gemini_manager():
    # Reset the manager in case earlier tests left memory information behind.
    manager = StatefulTensor.GST_MGR
    manager.reset()

    # occupies 8 bytes (2 * 2 fp16) on cuda
    st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
    # occupies 60 bytes (3 * 5 fp32) on cpu
    st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

    # occupies 28 bytes (7 fp32) on cuda
    t1 = torch.empty(7, device='cuda')
    # occupies 12 bytes (3 fp32) on cpu
    t2 = torch.empty(3, device='cpu')
    st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
    st4 = StatefulTensor(None, TensorState.FREE)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 60
    assert manager.total_mem['cuda'] == 36
    assert manager.state_mem['cpu'][TensorState.HOLD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28

    st4.payload_reset(t2)
    st3.payload_reset(t2)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 84
    assert manager.total_mem['cuda'] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD] == 72
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0

    st1.move_to(torch.device('cpu'))
    st2.move_to(torch.device('cpu'))
    st3.move_to(torch.device('cuda', 0))

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 80
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12

    st1.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.HOLD_AFTER_BWD)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
    assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
    assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0


if __name__ == '__main__':
    test_gemini_manager()
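As a reading aid (not part of the commit), the "occupies ... bytes" figures asserted above follow directly from numel() multiplied by element_size(): torch.empty defaults to float32 (4 bytes per element) and float16 uses 2 bytes.

# Reading aid only: re-derive the byte counts used in the test above.
import torch


def nbytes(*shape, dtype=torch.float32):
    """Bytes occupied by a dense tensor of the given shape and dtype."""
    n = 1
    for s in shape:
        n *= s
    return n * torch.empty(0, dtype=dtype).element_size()


assert nbytes(2, 2, dtype=torch.float16) == 8   # st1 on cuda
assert nbytes(3, 5, dtype=torch.float32) == 60  # st2 on cpu
assert nbytes(7) == 28                          # t1, default fp32, on cuda
assert nbytes(3) == 12                          # t2, default fp32, on cpu
# Initial totals in the test: cuda = 8 + 28 = 36, cpu = 60.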
@@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_set_process_memory_fraction
from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter
@@ -1,7 +1,7 @@
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
import colossalai

import torch
@@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@@ -2,9 +2,10 @@ import pytest

import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
                                           colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                           colo_model_tensor_clone)
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
                                            colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                            colo_model_tensor_clone)
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
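The hunks above only relocate imports from colossalai.zero.sharded_param to the new colossalai.gemini modules. A minimal smoke-use of the relocated paths, assuming nothing beyond what the new test itself demonstrates (StatefulTensor, TensorState, the global StatefulTensor.GST_MGR), is sketched below; it is not part of the commit.

# Smoke-use of the relocated modules; import paths are taken verbatim from the hunks above.
import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline  # relocated helper, not called here

st = StatefulTensor(torch.empty(4, 4))                # 4 * 4 * 4 = 64 bytes of fp32 on CPU
assert StatefulTensor.GST_MGR.total_mem['cpu'] >= 64  # the global manager recorded the registration
st.trans_state(TensorState.COMPUTE)                   # state transitions move bytes between state buckets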