Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-03 18:46:43 +00:00)
[feat] fix optimizer bwd b & w; support return accum loss & output
Commit 48ba22dbfd (parent 4c4b01b859)
@@ -58,7 +58,7 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)
 
-    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+    def backward_b_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
         """
         Performs a backward pass for dx, we only calculate dx = w*dy here
 
@@ -69,16 +69,28 @@ class OptimizerWrapper:
             retain_graph (bool): default to be True, we retain graph in backward_b
         """
         torch.autograd.backward(
-            tensors=tensor,
+            tensors=tensors,
             grad_tensors=grad_tensors,
             inputs=inputs,
             retain_graph=retain_graph,
         )
 
-    def backward_w_by_grad():
+    def backward_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = False):
         """
         Performs a backward pass for dw, we only calculate dw = x*dy here
+
+        Args:
+            tensor (Tensor): y or loss of current chunk;
+            grad_tensors (Tensor): dy of current chunk;
+            input_obj (Tensor): w;
+            retain_graph (bool): default to be False, we release graph in backward_w
         """
+        torch.autograd.backward(
+            tensors=tensors,
+            grad_tensors=grad_tensors,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
     def state_dict(self):
         """
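Note: the two wrappers above only forward to torch.autograd.backward with an explicit inputs= argument, which is what lets the ZeroBubble schedule split one backward pass into a B step (activation gradient, dx = w*dy) and a deferred W step (weight gradient, dw = x*dy). A minimal sketch of that split, assuming a single nn.Linear standing in for one chunk; the names below are illustrative and not part of the commit:

import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 4)                  # stands in for one pipeline chunk's weights w
x = torch.randn(2, 4, requires_grad=True)      # activation received from the previous stage
y = layer(x)
dy = torch.ones_like(y)                        # gradient arriving from the next stage

# B step: only dx; keep the graph alive for the later W step (retain_graph=True)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=x, retain_graph=True)

# W step: only dw; the graph can now be released (retain_graph=False)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(layer.parameters()), retain_graph=False)

# Both partial passes reproduce the analytic gradients of y = x @ W^T + b
assert torch.allclose(x.grad, dy @ layer.weight)                 # dx = dy * W
assert torch.allclose(layer.weight.grad, dy.t() @ x.detach())    # dW = dy^T * x
assert torch.allclose(layer.bias.grad, dy.sum(dim=0))            # db = sum(dy)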
@@ -13,7 +13,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -51,8 +51,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        self.is_first_run = True
-        self.optimizer = None
+        # self.is_first_run = True
+        # self.optimizer = None
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -405,6 +405,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss.add_(loss.detach())
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
+            # print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}")
             return loss
         else:
             return output_obj
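The hunk above accumulates the per-microbatch loss and outputs on the loss-holding stage. Detaching before accumulating matters: only the values are kept, so each microbatch's autograd graph can be freed. A small sketch of the same pattern, with merge_batch approximated by torch.cat; the names and shapes here are illustrative, not the schedule's actual buffers:

import torch

accum_loss = torch.scalar_tensor(0.0)
outputs = []

for _ in range(4):                                      # 4 microbatches
    micro_out = torch.randn(2, 8, requires_grad=True)   # pretend forward output
    micro_loss = micro_out.mean()
    accum_loss.add_(micro_loss.detach())                # keep the value, drop the graph
    outputs.append(micro_out.detach())                  # same for the returned outputs

merged = torch.cat(outputs, dim=0)                      # rough stand-in for merge_batch
print(accum_loss, merged.shape)                         # scalar loss, (8, 8) merged outputs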
@@ -438,17 +439,36 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         if model_chunk_id == 0:
             # bwd step
-            torch.autograd.backward(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+            # torch.autograd.backward(
+            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+            # )
+            optimizer.backward_b_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=input_obj,
+                retain_graph=True,
             )
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
+                # torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True)
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=input_obj,
+                    retain_graph=True,
+                )
+
             else:
                 # commom bwd step
-                torch.autograd.backward(
-                    tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+                # torch.autograd.backward(
+                #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+                # )
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=output_obj_grad,
+                    inputs=input_obj,
+                    retain_graph=True,
                 )
 
         return input_obj.grad
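In the branch above where output_obj is already the loss (the loss-holding stage), there is no incoming dy, so grad_tensors is passed as None and autograd falls back to the implicit gradient of 1.0 for a scalar. A short, self-contained illustration of that behaviour, independent of the scheduler:

import torch

x = torch.randn(3, requires_grad=True)
loss = (2 * x).sum()                       # scalar, like a chunk's loss output
torch.autograd.backward(tensors=loss, grad_tensors=None, inputs=x, retain_graph=True)
assert torch.allclose(x.grad, torch.full_like(x, 2.0))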
@@ -457,7 +477,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -475,15 +495,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
-            torch.autograd.backward(
+            # torch.autograd.backward(
+            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
+            # )
+            optimizer.backward_w_by_grad(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
             )
 
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
+                # torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()))
+                optimizer.backward_w_by_grad(
+                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                )
             else:
-                torch.autograd.backward(
+                # torch.autograd.backward(
+                #     tensors=output_obj,
+                #     grad_tensors=output_obj_grad,
+                #     inputs=list(model_chunk[model_chunk_id].parameters()),
+                # )
+
+                optimizer.backward_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
@@ -535,7 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
@@ -641,7 +672,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
     ):
         """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
 
@@ -660,7 +691,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
@@ -677,16 +708,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
-        # # prepare batch
+        # prepare batch
         self.load_batch(data_iter)
         print(
             f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
         )
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # reset accum loss at fwd end;
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
         it = 0
         # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
 
             print(
                 f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
             )
@@ -706,8 +747,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
                     criterion=criterion,
-                    accum_loss=return_loss,
-                    outputs=return_outputs,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
                 )
             elif scheduled_node.type == "B":
                 self.schedule_b(
@@ -721,5 +762,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
                 )
             it += 1
+
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
@@ -672,7 +672,7 @@ def run_fwd_bwd_vschedule_with_optim(
     batch_size = batch_size
     num_layers = 8
     assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-    in_dim = out_dim = 8
+    in_dim = out_dim = 16
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
@@ -714,15 +714,17 @@ def run_fwd_bwd_vschedule_with_optim(
     )
 
     torch.cuda.synchronize()
-    scheduler.run_forward_backward(
+    result = scheduler.run_forward_backward(
         model_chunk=local_chunk,
         data_iter=iter(data_iter),
         criterion=criterion,
         optimizer=optimizer_pp,
-        return_loss=None,
-        return_outputs=None,
+        return_loss=True,
+        return_outputs=True,
     )
 
+    optimizer_pp.step()
+
     ##########################
     # Fwd bwd for base
     ##########################
@@ -733,6 +735,15 @@ def run_fwd_bwd_vschedule_with_optim(
     optimizer_base.step()
     print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
+    ##########################
+    # assert loss & output
+    ##########################
+    # only chunk 1 stage 0 hold loss and output
+    if rank == 0:
+        assert_close(result["loss"], loss_base)
+        assert_close(result["outputs"], output_base)
+
+    # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
     ##########################
     # assert weight
     ##########################
@@ -768,6 +779,18 @@ def run_fwd_bwd_vschedule_with_optim(
     ##########################
     # assert optim state
     ##########################
+    optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0]
+    optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
+
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
+        if key_base == key_pp:
+            if key_base != "params":
+                assert val_base == val_pp
+            else:
+                # BUG:
+                # param_base: [0, 1, 2, 3, 4, 5, 6, 7];
+                # params pp: [0, 1];
+                assert val_base[:2] == val_pp
 
 
 @pytest.mark.dist