diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 299ea0518..c4a4f245a 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -361,10 +361,11 @@ class Chunk:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
+        maybe_work = None
         if not self.is_gathered:
-            return self.__gather(async_op=async_access)
+            maybe_work = self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
-        return None
+        return maybe_work
 
     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index cab26c822..736238a09 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -5,7 +5,6 @@ from typing import List
 
 import torch
 
-from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -17,9 +16,6 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
-logger = DistributedLogger("gemini_hook")
-
-
 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 0a8e0ae4a..83e475575 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -177,6 +177,10 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None
 
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def compute_list(self) -> List[Tuple[Chunk, ...]]:
         return self._compute_list
@@ -189,10 +193,6 @@ class GeminiManager:
     def async_works(self) -> Dict[Chunk, dist.Work]:
         return self._async_works
 
-    @property
-    def placement_policy(self) -> PlacementPolicy:
-        return self._placement_policy
-
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 570a0aa42..478ace3d4 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("max_prefetch", [0, 1, 4])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    max_prefetch: int = 0,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
     model = GeminiDDP(
-        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+        model,
+        config_dict,
+        init_device,
+        pin_memory=True,
+        **placement_config,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index fd0e9fd7c..11d29c50f 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("max_prefetch", [0, 1, 4])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    max_prefetch: int,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,6 +87,7 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 0a9bac092..ad6dc2f78 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, max_prefetch: int):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         **placement_config,
     )
 
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 1c914ca0e..eab55f190 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -71,7 +71,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_model_step(
+    placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, max_prefetch: int
+):
     set_seed(42)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -94,7 +97,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
     model = GeminiDDP(
-        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+        model,
+        config_dict,
+        **placement_config,
+        mixed_precision=mixed_precision,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
index 23e2d8083..3cbd36917 100644
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -28,7 +28,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
 @parameterize("master_weights", [False, True])
-def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int):
     set_seed(431)
     model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
 
@@ -44,7 +45,14 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
+    model = GeminiDDP(
+        model,
+        config_dict,
+        **placement_config,
+        pin_memory=True,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
+    )
     model.train()
 
     zero_dict = model.state_dict(only_rank_0=False)
diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
index 8d70ae3b1..a721c96a1 100644
--- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -20,7 +20,8 @@ PLACEMENT_CONFIGS = [
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
-def exam_zero_optim_state_dict(placement_config, keep_gathered):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch):
     set_seed(431)
     model_builder, data_gen_fn, output_transform_fn, *_ = next(
         iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
@@ -35,7 +36,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
 
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch)
 
     optimizer = HybridAdam(model.parameters())
     optim = GeminiOptimizer(optimizer, model, initial_scale=32)  # initialize the link between chunk16 and chunk32
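For context, a minimal sketch of how the `max_prefetch` knob exercised by these tests reaches `GeminiDDP` from user code. Only `GeminiDDP`, `GeminiOptimizer`, `HybridAdam`, and the `pin_memory`/`max_prefetch` arguments are taken from the patch above; the toy `nn.Sequential` model, the hyperparameter values, and the assumption that the distributed environment has already been launched (e.g. via `torchrun` plus `colossalai.launch_from_torch`) are illustrative assumptions, not part of this change.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

# Assumption: torch.distributed / colossalai have already been initialized on this rank.
model = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))

# max_prefetch > 0 allows Gemini to gather upcoming chunks asynchronously
# (via the work handle now returned by Chunk.access_chunk) while compute proceeds.
model = GeminiDDP(model, chunk_init_device=torch.device("cuda"), pin_memory=True, max_prefetch=2)
optimizer = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=1)

data = torch.randn(8, 32, device="cuda")
loss = model(data).sum()
optimizer.backward(loss)  # GeminiOptimizer owns loss scaling and the backward pass
optimizer.step()
optimizer.zero_grad()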