diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index a236434a5..c8cf3ec21 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -49,11 +49,11 @@ class OptimizerWrapper:
         """
         self.optim.zero_grad(*args, **kwargs)
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         """
         Performs a backward pass on the loss.
         """
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 80b2c7961..d2754cbd9 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -373,7 +373,7 @@ class GeminiDDP(ModelWrapper):
         loss.backward()
         self._post_backward()
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
         raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
 
     @staticmethod
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index fdf2a4976..ccd4634b5 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -298,12 +298,14 @@ class GeminiOptimizer(OptimizerWrapper):
         loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)
 
-    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+    def backward_by_grad(
+        self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
+    ):
         # This function is called except the last stage of pipeline parallel
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
         self.module.backward_by_grad(tensor, grad)
 
     def _maybe_move_fp32_params(self):
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 51d7d1eaa..9cc44c753 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -408,7 +408,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # torch.optim.Optimizer methods
     ################################
 
-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and no_sync are not compatible"
@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        loss.backward(retain_graph=retain_graph)
+        loss.backward(inputs=inputs, retain_graph=retain_graph)
 
         if not self.require_grad_sync:
             return
@@ -427,14 +427,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self._overlap_communication:
             get_accelerator().synchronize()
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
 
         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
-        torch.autograd.backward(tensor, grad)
+        torch.autograd.backward(
+            tensor,
+            grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
         if not self.require_grad_sync:
             return
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index ba6cafe6b..384ed6490 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -19,6 +19,8 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close
@@ -751,12 +753,13 @@ def run_with_hybridplugin(test_config):
     "config",
     [
         (0, 1, 4, 1, 1),
-        # (0, 2, 2, 1, 1),
-        # (0, 2, 1, 2, 1),
-        # (0, 2, 1, 1, 2),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
+    test_config = config
     stage, ep_size, pp_size, tp_size, sp_size = config
     num_microbatches = pp_size
     dist.get_world_size()
@@ -865,8 +868,15 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             )
             # stage 0 chunk 0
             parallel_output = None
-            if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
+            if (
+                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
+                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
+            ):
                 parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+            # broadcast along pp axis
+            dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
 
         else:
             # for test without pp
@@ -874,7 +884,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        # dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
@@ -891,8 +901,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         p.grad /= dp_size
     torch_optimizer.step()
     torch_optimizer.zero_grad()
-    if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
-        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    print(f"rank {dist.get_rank()} config {test_config} test passed")
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
 
 
 def run_dist(rank, world_size, port):
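
Note on the new arguments: inputs= and retain_graph= are forwarded straight to PyTorch autograd (Tensor.backward / torch.autograd.backward), where inputs restricts gradient accumulation to the listed tensors. That is what allows a zero-bubble pipeline schedule to run the activation-gradient pass and the weight-gradient pass as separate backward calls. A minimal sketch of that usage pattern, assuming a toy single-layer model (the names below are illustrative, not taken from this patch):

    import torch

    model = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    loss = model(x).sum()

    # B pass: gradients w.r.t. the stage input only; keep the graph for the later W pass.
    loss.backward(inputs=[x], retain_graph=True)

    # W pass: gradients w.r.t. the parameters, scheduled later to fill the pipeline bubble.
    loss.backward(inputs=list(model.parameters()))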