Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)
[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)
* polish chunk manager
* polish unit test
* impl add_extern_static_tensor for chunk mgr
* add mem stats collector v2
* polish code
* polish unit test
* polish code
* polish get chunks
@@ -32,7 +32,7 @@ HAS_TENSORS = {
 }
 }

-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}


 @parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0

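For orientation, the sketch below condenses the "+" side of this diff into one self-contained helper. It is a minimal sketch, not the test itself: the function name chunk_access_sketch, the ChunkManager import path, and the need to initialize the distributed context via colossalai.launch are assumptions; the step that registers the parameters with the chunk manager sits between the two hunks above and is left as a comment rather than guessed at.

import torch
from colossalai.tensor import ChunkManager  # assumed import path at this revision


def chunk_access_sketch(use_chunk: bool = True, use_zero: bool = False):
    """Hypothetical helper mirroring the post-commit test body.

    Assumes an already-initialized ColossalAI distributed context
    (e.g. via colossalai.launch) before it is called.
    """
    params = [torch.rand(8, 8) for _ in range(3)]
    chunk_size = 128 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)

    # ... the parameters are registered with the chunk manager here; that part of
    # the test falls between the two hunks above and is not reproduced ...

    # New in this commit: fetch the chunks once via get_chunks() and then access,
    # release, and move whole chunks instead of individual parameter tensors.
    chunks = chunk_manager.get_chunks(params)
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)    # test expects every rank to then hold all params
    for chunk in chunks:
        chunk_manager.release_chunk(chunk)   # test expects the HAS_TENSORS ownership pattern again
    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))  # test expects memory accounted on CPU
    assert chunk_manager.total_mem['cuda'] == 0  # final assertion from the diff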