mirror of https://github.com/hpcaitech/ColossalAI.git
[feat] update optimizer bwd
This commit is contained in:
parent d63479553c
commit 5c8bbf63a8
@@ -49,11 +49,11 @@ class OptimizerWrapper:
         """
         self.optim.zero_grad(*args, **kwargs)
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         """
         Performs a backward pass on the loss.
         """
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
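
Not part of the diff: the new keyword arguments map directly onto torch.Tensor.backward(inputs=..., retain_graph=...), where inputs restricts gradient accumulation to the listed tensors and retain_graph keeps the autograd graph alive for a later backward pass. A minimal sketch of that PyTorch behaviour, using a plain nn.Linear with illustrative names (not ColossalAI code):

import torch
import torch.nn as nn

# Minimal sketch: split one backward pass into two restricted passes.
model = nn.Linear(4, 1)
x = torch.randn(2, 4)
loss = model(x).sum()

# First pass: accumulate gradients only into the weight, keep the graph alive.
loss.backward(inputs=[model.weight], retain_graph=True)

# Second pass over the same graph: now accumulate the bias gradient.
loss.backward(inputs=[model.bias])

print(model.weight.grad.shape, model.bias.grad.shape)

This is the kind of split (gradients for some tensors now, others later, over a retained graph) that pipeline schedules rely on, which is why the wrapper now exposes both arguments instead of hiding them behind *args.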
@@ -373,7 +373,7 @@ class GeminiDDP(ModelWrapper):
             loss.backward()
         self._post_backward()
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
         raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
 
     @staticmethod
@@ -298,12 +298,14 @@ class GeminiOptimizer(OptimizerWrapper):
         loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)
 
-    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+    def backward_by_grad(
+        self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
+    ):
         # This function is called except the last stage of pipeline parallel
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
         self.module.backward_by_grad(tensor, grad)
 
     def _maybe_move_fp32_params(self):
@@ -408,7 +408,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # torch.optim.Optimizer methods
     ################################
 
-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and no_sync are not compatible"
@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        loss.backward(retain_graph=retain_graph)
+        loss.backward(inputs=inputs, retain_graph=retain_graph)
 
         if not self.require_grad_sync:
             return
@@ -427,14 +427,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self._overlap_communication:
             get_accelerator().synchronize()
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
 
         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
-        torch.autograd.backward(tensor, grad)
+        torch.autograd.backward(
+            tensor,
+            grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
         if not self.require_grad_sync:
             return
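
Not part of the diff: torch.autograd.backward(tensor, grad, inputs=..., retain_graph=...) is the explicit form used when a non-last pipeline stage backpropagates from its output with a gradient received from the next stage. A standalone sketch of that call, with made-up tensor names and shapes purely for illustration:

import torch

# Stand-in for this stage's forward output and the gradient sent back by the
# next stage; in the real schedule these come from the pipeline runtime.
x = torch.randn(2, 3, requires_grad=True)
stage_output = x * 2.0
grad_from_next_stage = torch.ones_like(stage_output)

# Restrict accumulation to `inputs` and keep the graph for a possible second
# backward pass (e.g. a separate weight-gradient pass in zero-bubble schedules).
torch.autograd.backward(
    stage_output,
    grad_from_next_stage,
    inputs=[x],
    retain_graph=True,
)
print(x.grad)  # all elements equal 2.0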
@@ -19,6 +19,8 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close
@@ -751,12 +753,13 @@ def run_with_hybridplugin(test_config):
     "config",
     [
         (0, 1, 4, 1, 1),
-        # (0, 2, 2, 1, 1),
-        # (0, 2, 1, 2, 1),
-        # (0, 2, 1, 1, 2),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
+    test_config = config
     stage, ep_size, pp_size, tp_size, sp_size = config
     num_microbatches = pp_size
     dist.get_world_size()
@@ -865,8 +868,15 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             )
             # stage 0 chunk 0
             parallel_output = None
-            if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
+            if (
+                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
+                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
+            ):
                 parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+            # broadcast along pp axis
+            dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
 
         else:
             # for test without pp
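
Not part of the diff: the test change follows the usual collective pattern of seeding a placeholder tensor on ranks that do not own the loss and broadcasting from the owning rank, so every pipeline rank holds the same value before later collectives such as the dp all_reduce. A self-contained sketch of that pattern on a hypothetical 2-process gloo group (addresses, port, ranks, and the placeholder value are all illustrative):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    src = 0  # stand-in for dist.get_process_group_ranks(pp_group)[0]
    # Only the source rank owns the real loss; others hold a sentinel placeholder.
    loss = torch.tensor(0.7315) if rank == src else torch.tensor(12345.0)
    dist.broadcast(loss, src=src)
    assert loss.item() != 12345.0  # every rank now sees the broadcast value
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)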
@@ -874,7 +884,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        # dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
@@ -891,8 +901,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             p.grad /= dp_size
     torch_optimizer.step()
     torch_optimizer.zero_grad()
-    if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
-        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    print(f"rank {dist.get_rank()} config {test_config} test passed")
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
 
 
 def run_dist(rank, world_size, port):