diff --git a/colossalai/gemini/update/__init__.py b/colossalai/gemini/update/__init__.py new file mode 100644 index 000000000..cc181cfa4 --- /dev/null +++ b/colossalai/gemini/update/__init__.py @@ -0,0 +1 @@ +from .chunkv2 import ChunkV2 diff --git a/colossalai/gemini/ag_chunk.py b/colossalai/gemini/update/chunkv2.py similarity index 97% rename from colossalai/gemini/ag_chunk.py rename to colossalai/gemini/update/chunkv2.py index cdeb78222..b1bbfce5e 100644 --- a/colossalai/gemini/ag_chunk.py +++ b/colossalai/gemini/update/chunkv2.py @@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF free_storage, alloc_storage -class AgChunk: +class ChunkV2: def __init__(self, chunk_size: int, process_group: ColoProcessGroup, diff --git a/tests/test_gemini/chunk/test_agchunk.py b/tests/test_gemini/update/test_chunkv2.py similarity index 95% rename from tests/test_gemini/chunk/test_agchunk.py rename to tests/test_gemini/update/test_chunkv2.py index 005c6503b..deea46acb 100644 --- a/tests/test_gemini/chunk/test_agchunk.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ColoParameter from colossalai.gemini import TensorState -from colossalai.gemini.ag_chunk import AgChunk +from colossalai.gemini.update import ChunkV2 def dist_sum(x): @@ -38,7 +38,7 @@ def check_euqal(param, param_cp): def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = ColoProcessGroup() - my_chunk = AgChunk( + my_chunk = ChunkV2( chunk_size=1024, process_group=pg, dtype=torch.float32,