mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[feat] add test run_fwd_bwd automatic scheduling;
This commit is contained in:
@@ -176,6 +176,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||
#################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.recv_forward_buffer[model_chunk_id].append(None)
|
||||
return None, []
|
||||
|
||||
################
|
||||
@@ -188,6 +189,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
else:
|
||||
@@ -200,7 +202,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, []
|
||||
|
||||
################
|
||||
@@ -214,7 +216,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
||||
@@ -240,6 +242,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, []
|
||||
|
||||
################
|
||||
@@ -252,6 +255,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
else:
|
||||
@@ -261,6 +265,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; get loss from local
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.recv_backward_buffer[model_chunk_id].append(None)
|
||||
return None, []
|
||||
|
||||
################
|
||||
@@ -268,16 +273,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# self.comm.recv_backward recv bwd from prev stage;
|
||||
################
|
||||
else:
|
||||
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
|
||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
@@ -291,6 +296,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 && is_last_stage
|
||||
@@ -330,7 +336,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
|
||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
@@ -359,6 +365,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# Send dx to PREV stage;
|
||||
################
|
||||
else:
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
|
||||
# send_metadata=self.send_grad_metadata
|
||||
@@ -371,6 +378,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# hold dy to local_send_bwd_buffer;
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
self.local_send_backward_buffer.append(input_tensor_grad)
|
||||
return []
|
||||
|
||||
@@ -379,6 +387,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# Send dx to NEXT stage;
|
||||
################
|
||||
else:
|
||||
print(
|
||||
f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
|
||||
)
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
|
||||
@@ -413,6 +425,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# Only attention_mask from micro_batch is used
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd calculate
|
||||
output_obj = model_chunk[model_chunk_id](input_obj)
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
@@ -463,6 +476,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# commom bwd step
|
||||
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
|
||||
# BUG:output_obj_grad is None
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
|
||||
)
|
||||
@@ -505,14 +519,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
outputs: Optional[List[Any]] = None,
|
||||
):
|
||||
# Step1: recv fwd
|
||||
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# # first layer
|
||||
# input_obj = input_obj
|
||||
# else:
|
||||
# # other layer
|
||||
# input_obj, wait_handles = self.recv_forward(model_chunk_id)
|
||||
# # print(f"recv input_obj {input_obj}")
|
||||
# _wait_p2p(wait_handles)
|
||||
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# first layer
|
||||
input_obj = input_obj
|
||||
self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
|
||||
else:
|
||||
# other layer
|
||||
input_obj, wait_handles = self.recv_forward(model_chunk_id)
|
||||
# print(f"recv input_obj {input_obj}")
|
||||
_wait_p2p(wait_handles)
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
model_chunk=model_chunk,
|
||||
@@ -522,6 +543,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
|
||||
|
||||
# add input and output object for backward b
|
||||
@@ -532,7 +554,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||
|
||||
# Step3: send fwd
|
||||
send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
|
||||
# add output to send_fwd_buffer
|
||||
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||
# send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
|
||||
|
||||
def schedule_b(
|
||||
self,
|
||||
@@ -545,17 +569,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# output_obj_grad: Optional[dict],
|
||||
):
|
||||
# Step1: recv bwd
|
||||
# not first stage and chunk 1
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad, recv_bwd_handles = None, []
|
||||
# print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
else:
|
||||
output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
|
||||
# print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
# # not first stage and chunk 1
|
||||
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# output_tensor_grad, recv_bwd_handles = None, []
|
||||
# # print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
# else:
|
||||
# output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
|
||||
# # print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
|
||||
|
||||
# get input and output object from buffer;
|
||||
input_obj = self.input_tensors[model_chunk_id].pop()
|
||||
output_obj = self.output_tensors[model_chunk_id].pop()
|
||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||
|
||||
# save output_tensor_grad for dw
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
@@ -565,9 +592,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# we save output_tensor_grad here
|
||||
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||
|
||||
_wait_p2p(recv_bwd_handles)
|
||||
# _wait_p2p(recv_bwd_handles)
|
||||
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
|
||||
# Step2: bwd step
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
|
||||
|
||||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
@@ -576,23 +606,23 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
)
|
||||
print(f"input_object_grad {input_object_grad}")
|
||||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
|
||||
|
||||
# Step3: send bwd
|
||||
send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
|
||||
# send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
|
||||
def schedule_w(
|
||||
self,
|
||||
scheduled_node,
|
||||
non_w_pending,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
):
|
||||
|
||||
# get y & dy from buffer
|
||||
output_obj = self.output_tensors_dw[model_chunk_id].pop()
|
||||
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
|
||||
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
||||
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
||||
|
||||
self.backward_w_step(
|
||||
model_chunk=model_chunk,
|
||||
@@ -605,6 +635,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
input_obj: Optional[dict],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
@@ -615,19 +646,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# while we still have schedules_node in self.schedules
|
||||
while it < len(self.schedules):
|
||||
scheduled_node = self.schedules[it]
|
||||
print(f"it {it}; scheduled_node {scheduled_node};")
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
if scheduled_node.type == "RECV_FORWARD":
|
||||
self.recv_forward()
|
||||
self.recv_forward(scheduled_node.chunk)
|
||||
elif scheduled_node.type == "RECV_BACKWARD":
|
||||
self.recv_backward()
|
||||
self.recv_backward(scheduled_node.chunk)
|
||||
elif scheduled_node.type == "SEND_FORWARD":
|
||||
self.send_forward()
|
||||
self.send_forward(scheduled_node.chunk)
|
||||
elif scheduled_node.type == "SEND_BACKWARD":
|
||||
self.send_backward()
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f()
|
||||
self.send_backward(scheduled_node.chunk)
|
||||
if scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=return_loss,
|
||||
outputs=return_outputs,
|
||||
)
|
||||
elif scheduled_node.type == "B":
|
||||
self.schedule_b()
|
||||
self.schedule_b(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
)
|
||||
elif scheduled_node.type == "W":
|
||||
self.schedule_w()
|
||||
self.schedule_w(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
)
|
||||
it += 1
|
||||
|
Reference in New Issue
Block a user