[feat] add dw test;
@@ -64,8 +64,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def _free_buffers(self):
         # free local buffer
         # two dim array, first dim is the model chunk, second dim is the microbatch queue
+
+        # x & y buffer for schedule b
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
+
+        # y & dy buffer for schedule w
+        self.output_tensors_dw = [[], []]
+        self.output_tensors_grad_dw = [[], []]
+
         self.send_forward_buffer = [[], []]
         self.recv_forward_buffer = [[], []]
         self.send_backward_buffer = [[], []]
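For readers skimming the diff: each of these buffers is indexed first by model chunk, and the inner list acts as a per-microbatch queue. A minimal standalone sketch of that layout (toy string stand-ins, not scheduler code):

# hypothetical illustration of the two-dim buffer layout (not ColossalAI code)
input_tensors = [[], []]  # first dim: model chunk; second dim: microbatch queue

# forward of microbatch 0 on chunk 0 queues its input ...
input_tensors[0].append("x_chunk0_mb0")  # stand-in for the saved activation
# ... and the matching backward b step pops it back off the same queue
x = input_tensors[0].pop()
assert x == "x_chunk0_mb0"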
@@ -467,7 +474,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
+        # input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -479,8 +486,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters()))
-
+                torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
             else:
                 torch.autograd.backward(
                     tensors=output_obj,
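The dropped line passed model= as a keyword argument to list(), which raises a TypeError; the surviving call is the valid form. More generally, the inputs= argument of torch.autograd.backward is what lets the scheduler split one backward pass into a b step (input gradients only) and a w step (weight gradients only). A minimal sketch with a toy layer (illustrative names, not the scheduler's API):

import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)
dy = torch.randn_like(y)

# b step: accumulate grads only into the input (dx); keep the graph alive for w
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and layer.weight.grad is None

# w step: later, accumulate grads only into the parameters (dw)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(layer.parameters()))
assert layer.weight.grad is not None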
@@ -518,10 +524,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
 
-        # add input and output object for backward
+        # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
 
+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(output_obj)
+
         # Step3: send fwd
         send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
 
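Note that the forward step now queues the same output_obj twice: once for the b pass and once for the w pass, so the two backward phases can pop it independently. A toy sketch of that double bookkeeping (hypothetical helper, not the real scheduler):

# hypothetical bookkeeping mirroring the forward step above
output_tensors = [[], []]     # popped by backward b
output_tensors_dw = [[], []]  # popped later by backward w

def record_forward(model_chunk_id, output_obj):
    # the same object is queued twice so b and w can consume it independently
    output_tensors[model_chunk_id].append(output_obj)
    output_tensors_dw[model_chunk_id].append(output_obj)

record_forward(0, "y_chunk0_mb0")
assert output_tensors[0][-1] is output_tensors_dw[0][-1]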
@@ -544,10 +553,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
         # print(f"recv output_tensor_grad {output_tensor_grad}")
 
-        # get input and output object from buffer
+        # get input and output object from buffer;
         input_obj = self.input_tensors[model_chunk_id].pop()
         output_obj = self.output_tensors[model_chunk_id].pop()
 
+        # save output_tensor_grad for dw
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # we save loss here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
+        else:
+            # we save output_tensor_grad here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
+
         _wait_p2p(recv_bwd_handles)
         # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
         # Step2: bwd step
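The branch reflects where dy comes from: on the stage that produced the loss (model chunk 1 on the first device of this V schedule) there is no received gradient, so the popped output_obj, which is the loss itself, seeds the w pass; everywhere else the received output_tensor_grad is saved. A hedged sketch of the two seeds (toy layer; the condition and names only mirror the diff):

import torch

layer = torch.nn.Linear(4, 1)
x = torch.randn(2, 4)

# loss-owning stage: backward starts from the scalar loss, no grad_tensors needed
loss = layer(x).sum()
torch.autograd.backward(loss, inputs=list(layer.parameters()))

# intermediate stage: backward is seeded by the received dy
layer.zero_grad()
y = layer(x)
dy = torch.randn_like(y)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(layer.parameters()))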
@@ -571,15 +588,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
 
+        # get y & dy from buffer
+        output_obj = self.output_tensors_dw[model_chunk_id].pop()
+        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
+
         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             # optimizer: OptimizerWrapper,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
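Finally, in the spirit of the commit's "dw test", a standalone sanity check (toy model, not the repository's test) that a split b-then-w backward reproduces the gradients of a single fused backward:

import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
dy = torch.randn(2, 4)

# reference: one fused backward computes dx and dw together
y = layer(x)
y.backward(dy)
ref_wgrad = layer.weight.grad.clone()

# split: b step (dx only, graph retained), then w step (dw only)
layer.zero_grad()
x.grad = None
y = layer(x)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(layer.parameters()))
torch.testing.assert_close(layer.weight.grad, ref_wgrad)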