Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[feat] fix optimizer bwd b & w; support return accum loss & output
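For context: the "bwd b & w" in the title refers to splitting each stage's backward pass into a B step (gradient with respect to the stage input, which unblocks the previous pipeline stage) and a W step (gradient with respect to the weights, which can be deferred to fill pipeline bubbles). Below is a minimal sketch of that split using plain torch.autograd; the optimizer.backward_b_by_grad / backward_w_by_grad wrappers adopted in this commit are assumed to behave roughly like these calls, and the toy layer and shapes are illustrative only, not from the commit:

import torch

# toy single-layer "stage": y = x @ w
w = torch.randn(4, 4, requires_grad=True)   # stage weights
x = torch.randn(2, 4, requires_grad=True)   # stage input (activation from the previous stage)
y = x @ w
dy = torch.randn_like(y)                    # gradient received from the next stage

# B step: only d(loss)/dx; retain the graph so the W step can reuse it later
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
dx = x.grad                                 # sent back to the previous pipeline stage

# W step: only d(loss)/dw; for this linear layer, dw = x^T @ dy
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w])
assert torch.allclose(w.grad, x.t() @ dy)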
@@ -13,7 +13,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -51,8 +51,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        self.is_first_run = True
-        self.optimizer = None
+        # self.is_first_run = True
+        # self.optimizer = None
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -405,6 +405,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 accum_loss.add_(loss.detach())
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
+            # print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}")
             return loss
         else:
             return output_obj
@@ -438,17 +439,36 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         if model_chunk_id == 0:
             # bwd step
-            torch.autograd.backward(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+            # torch.autograd.backward(
+            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+            # )
+            optimizer.backward_b_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=input_obj,
+                retain_graph=True,
             )
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
+                # torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True)
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=input_obj,
+                    retain_graph=True,
+                )
+
             else:
                 # commom bwd step
-                torch.autograd.backward(
-                    tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+                # torch.autograd.backward(
+                #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+                # )
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=output_obj_grad,
+                    inputs=input_obj,
+                    retain_graph=True,
+                )
 
         return input_obj.grad
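The hunk above reroutes the B step through the optimizer wrapper instead of calling torch.autograd.backward directly. The body of backward_b_by_grad / backward_w_by_grad is not part of this diff; a plausible minimal stand-in (an assumption, ignoring the loss scaling a real mixed-precision OptimizerWrapper may add) would simply forward the arguments:

import torch

class SimpleOptimizerWrapper:
    # Hypothetical stand-in for the two OptimizerWrapper methods used above;
    # not the actual colossalai implementation.
    def __init__(self, optim):
        self.optim = optim

    def backward_b_by_grad(self, tensors, grad_tensors, inputs, retain_graph=True):
        # B step: accumulate gradients only into the stage inputs; keep the graph alive for the W step
        torch.autograd.backward(tensors=tensors, grad_tensors=grad_tensors, inputs=inputs, retain_graph=retain_graph)

    def backward_w_by_grad(self, tensors, grad_tensors, inputs, retain_graph=False):
        # W step: accumulate gradients only into the chunk's parameters; the graph can now be freed
        torch.autograd.backward(tensors=tensors, grad_tensors=grad_tensors, inputs=inputs, retain_graph=retain_graph)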
@@ -457,7 +477,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -475,15 +495,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
-            torch.autograd.backward(
+            # torch.autograd.backward(
+            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
+            # )
+            optimizer.backward_w_by_grad(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
             )
 
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
+                # torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()))
+                optimizer.backward_w_by_grad(
+                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                )
             else:
-                torch.autograd.backward(
+                # torch.autograd.backward(
+                #     tensors=output_obj,
+                #     grad_tensors=output_obj_grad,
+                #     inputs=list(model_chunk[model_chunk_id].parameters()),
+                # )
+
+                optimizer.backward_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
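The in-code comment "only dw = x*dy" describes the W step: gradient accumulation is restricted to the current chunk's parameters via inputs=list(model_chunk[model_chunk_id].parameters()), so other chunks on the same rank are untouched. A small sketch of that restriction (module names are illustrative only):

import torch
from torch import nn

# two model chunks living on the same pipeline rank, as in the V schedule
chunk0 = nn.Linear(4, 4, bias=False)
chunk1 = nn.Linear(4, 4, bias=False)

x = torch.randn(2, 4)
y = chunk1(chunk0(x))
dy = torch.randn_like(y)

# W step for chunk1 only: gradients flow into chunk1's weights, chunk0 is left untouched
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(chunk1.parameters()))
assert chunk1.weight.grad is not None
assert chunk0.weight.grad is None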
@@ -535,7 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
@@ -641,7 +672,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
     ):
         """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
 
@@ -660,7 +691,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
@@ -677,16 +708,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
-        # # prepare batch
+        # prepare batch
         self.load_batch(data_iter)
         print(
             f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
         )
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # reset accum loss at fwd end;
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
         it = 0
         # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
 
             print(
                 f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
             )
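The block added above prepares the accumulators that forward_step later fills: a scalar accum_loss on the loss-owning stage and an outputs list of detached per-microbatch outputs. Stripped of the pipeline machinery, the accumulation pattern looks like this (the toy criterion and shapes are made up for the example):

import torch

num_microbatch = 4
accum_loss = torch.scalar_tensor(0.0)      # counterpart of the tensor created on the loss stage
outputs = []                               # counterpart of the per-microbatch output buffer

for _ in range(num_microbatch):
    output = torch.randn(2, 3)             # stand-in for one microbatch's model-chunk output
    loss = output.mean() / num_microbatch  # criterion result, scaled by the number of microbatches
    accum_loss.add_(loss.detach())         # mirrors accum_loss.add_(loss.detach()) in forward_step
    outputs.append(output.detach())        # mirrors outputs.append(tree_map(detach, output_obj))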
@@ -706,8 +747,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
                     criterion=criterion,
-                    accum_loss=return_loss,
-                    outputs=return_outputs,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
                 )
             elif scheduled_node.type == "B":
                 self.schedule_b(
@@ -721,5 +762,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
                 )
             it += 1
+
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
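merge_batch (newly imported in the first hunk) collapses the list of per-microbatch outputs into a single batch before the scheduler returns {"loss": ..., "outputs": ...}. Assuming it concatenates leaf tensors along the batch dimension, which is the usual role of the _utils helpers, a stand-alone illustration of that merge would be (hypothetical helper, not the real implementation):

import torch

def merge_batch_sketch(outputs):
    # hypothetical: concatenate per-microbatch tensors along the batch dimension
    return torch.cat(outputs, dim=0)

micro_outputs = [torch.randn(2, 8) for _ in range(4)]   # 4 microbatches of size 2
merged = merge_batch_sketch(micro_outputs)
assert merged.shape == (8, 8)                           # restored to the full batch of 8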