[feat] add dw test;

This commit is contained in:
duanjunwen
2024-08-23 06:04:12 +00:00
parent ee9baedadf
commit c18ef060cf
2 changed files with 132 additions and 12 deletions

View File

@@ -64,8 +64,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
def _free_buffers(self):
# free local buffer
# two dim array, first dim is the model chunk, second dim is the microbatch queue
# x & y buffer for schedule b
self.input_tensors = [[], []]
self.output_tensors = [[], []]
# y & dy buffer for schedule b
self.output_tensors_dw = [[], []]
self.output_tensors_grad_dw = [[], []]
self.send_forward_buffer = [[], []]
self.recv_forward_buffer = [[], []]
self.send_backward_buffer = [[], []]
@@ -467,7 +474,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
# optimizer: OptimizerWrapper,
input_obj: Optional[dict],
# input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
@@ -479,8 +486,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters()))
torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
else:
torch.autograd.backward(
tensors=output_obj,
@@ -518,10 +524,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
# add input and output object for backward
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
# Step3: send fwd
send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
@@ -544,10 +553,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
# print(f"recv output_tensor_grad {output_tensor_grad}")
# get input and output object from buffer
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop()
output_obj = self.output_tensors[model_chunk_id].pop()
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else:
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
_wait_p2p(recv_bwd_handles)
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
# Step2: bwd step
@@ -571,15 +588,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
# optimizer: OptimizerWrapper,
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop()
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
self.backward_w_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
# optimizer: OptimizerWrapper,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_obj_grad,
)