From 039b7ed3bc33173e36c5c4decd41f8d7b1ec0f45 Mon Sep 17 00:00:00 2001 From: HELSON Date: Wed, 10 Aug 2022 16:40:29 +0800 Subject: [PATCH] [polish] add update directory in gemini; rename AgChunk to ChunkV2 (#1432) --- colossalai/gemini/update/__init__.py | 1 + colossalai/gemini/{ag_chunk.py => update/chunkv2.py} | 2 +- .../{chunk/test_agchunk.py => update/test_chunkv2.py} | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 colossalai/gemini/update/__init__.py rename colossalai/gemini/{ag_chunk.py => update/chunkv2.py} (97%) rename tests/test_gemini/{chunk/test_agchunk.py => update/test_chunkv2.py} (95%) diff --git a/colossalai/gemini/update/__init__.py b/colossalai/gemini/update/__init__.py new file mode 100644 index 000000000..cc181cfa4 --- /dev/null +++ b/colossalai/gemini/update/__init__.py @@ -0,0 +1 @@ +from .chunkv2 import ChunkV2 diff --git a/colossalai/gemini/ag_chunk.py b/colossalai/gemini/update/chunkv2.py similarity index 97% rename from colossalai/gemini/ag_chunk.py rename to colossalai/gemini/update/chunkv2.py index cdeb78222..b1bbfce5e 100644 --- a/colossalai/gemini/ag_chunk.py +++ b/colossalai/gemini/update/chunkv2.py @@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF free_storage, alloc_storage -class AgChunk: +class ChunkV2: def __init__(self, chunk_size: int, process_group: ColoProcessGroup, diff --git a/tests/test_gemini/chunk/test_agchunk.py b/tests/test_gemini/update/test_chunkv2.py similarity index 95% rename from tests/test_gemini/chunk/test_agchunk.py rename to tests/test_gemini/update/test_chunkv2.py index 005c6503b..deea46acb 100644 --- a/tests/test_gemini/chunk/test_agchunk.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ColoParameter from colossalai.gemini import TensorState -from colossalai.gemini.ag_chunk import AgChunk +from colossalai.gemini.update import ChunkV2 def dist_sum(x): @@ -38,7 +38,7 @@ def check_euqal(param, param_cp): def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = ColoProcessGroup() - my_chunk = AgChunk( + my_chunk = ChunkV2( chunk_size=1024, process_group=pg, dtype=torch.float32,