Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-06 20:10:08 +00:00)

Commit 8b37323f16 (parent 9e0bd1af00):
[feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
@@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: torch.nn.ModuleList,
         model_chunk_id: int,
-        input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
         outputs: Optional[List[Any]] = None,
@@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            scheduled_node:
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
-           input_obj (Optional[dict]): x;
            criterion (Callable): loss function;
            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
@@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 0:
             # is first stage; get input from func param
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = input_obj
+                input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
         else:
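For orientation: `load_micro_batch` replaces the old `input_obj` parameter by slicing the batch that `load_batch(data_iter)` staged earlier. A minimal sketch of such a helper, assuming `self.batch`, `self.microbatch_offset`, and `self.microbatch_size` are populated by `load_batch` (the actual method presumably lives on the schedule base class and may differ):

# Hypothetical sketch of a microbatch loader; the field names are assumptions.
def load_micro_batch(self, model_chunk_id: int):
    offset = self.microbatch_offset[model_chunk_id]
    micro_batch = self.batch[offset : offset + self.microbatch_size]
    self.microbatch_offset[model_chunk_id] += self.microbatch_size
    return micro_batch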
@@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
-        input_obj: Optional[dict],
         data_iter: Iterable,
         criterion: Callable[..., Any],
         optimizer: Optional[OptimizerWrapper] = None,
@@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # # prepare batch
         self.load_batch(data_iter)
-        # print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
+        print(
+            f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
+        )

         it = 0
         # while we still have schedules_node in self.schedules
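The `it` counter drives the schedule loop that the next hunk touches: it walks the per-stage list of `ScheduledNode`s and dispatches on `node.type`. A simplified sketch of that pattern; only the `schedule_f` keyword set is taken from this diff, the other handler names are assumptions:

# Simplified dispatch loop; names other than schedule_f are illustrative.
while it < len(self.schedules):
    node = self.schedules[it]
    if node.type in ("RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD", "SEND_BACKWARD"):
        self.run_communication(node)  # p2p send/recv for this node (assumed name)
    elif node.type == "F":
        # forward for one microbatch of one chunk (kwargs match the call below)
        self.schedule_f(
            scheduled_node=node,
            model_chunk=model_chunk,
            model_chunk_id=node.chunk,
            criterion=criterion,
            accum_loss=return_loss,
            outputs=return_outputs,
        )
    elif node.type == "B":
        # backward w.r.t. activations (dx); signature assumed
        self.schedule_b(scheduled_node=node, model_chunk=model_chunk, model_chunk_id=node.chunk)
    elif node.type == "W":
        # backward w.r.t. weights (dw), deferred to fill bubbles; signature assumed
        self.schedule_w(scheduled_node=node, model_chunk=model_chunk, model_chunk_id=node.chunk)
    it += 1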
@@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
-                    input_obj=input_obj,
                     criterion=criterion,
                     accum_loss=return_loss,
                     outputs=return_outputs,
@@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable


-# Test run_forward_backward with baseline;
-def test_run_fwd_bwd_base(
+# Test iter input & multiple microbatch
+def test_run_fwd_bwd_iter_input(
     rank: int,
     world_size: int,
     port: int,
@@ -47,7 +47,7 @@ def test_run_fwd_bwd_base(
     rank = dist.get_rank()
     pp_size = world_size
     pg_mesh = ProcessGroupMesh(pp_size)
-
+    num_microbatch = 4
     # stage_manager
     stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)

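With `num_microbatch = 4` here and `batch_size = 4` set later in this diff, each microbatch presumably carries one sample, assuming the usual even split:

# Assumed relation used when slicing the batch (illustrative):
microbatch_size = batch_size // num_microbatch  # 4 // 4 == 1 sample per microbatch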
@@ -55,6 +55,7 @@ def test_run_fwd_bwd_base(
     zbv_schedule = [
         # stage 0
         [
+            # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
@@ -73,9 +74,67 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
+            # microbatch 1
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1),
+            ScheduledNode(type="F", chunk=0, stage=0, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1),
+            ScheduledNode(type="F", chunk=1, stage=0, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1),
+            ScheduledNode(type="B", chunk=1, stage=0, minibatch=1),
+            ScheduledNode(type="W", chunk=1, stage=0, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1),
+            ScheduledNode(type="B", chunk=0, stage=0, minibatch=1),
+            ScheduledNode(type="W", chunk=0, stage=0, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
+            # microbatch 2
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2),
+            ScheduledNode(type="F", chunk=0, stage=0, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2),
+            ScheduledNode(type="F", chunk=1, stage=0, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2),
+            ScheduledNode(type="B", chunk=1, stage=0, minibatch=2),
+            ScheduledNode(type="W", chunk=1, stage=0, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2),
+            ScheduledNode(type="B", chunk=0, stage=0, minibatch=2),
+            ScheduledNode(type="W", chunk=0, stage=0, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
+            # microbatch 3
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3),
+            ScheduledNode(type="F", chunk=0, stage=0, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3),
+            ScheduledNode(type="F", chunk=1, stage=0, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3),
+            ScheduledNode(type="B", chunk=1, stage=0, minibatch=3),
+            ScheduledNode(type="W", chunk=1, stage=0, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3),
+            ScheduledNode(type="B", chunk=0, stage=0, minibatch=3),
+            ScheduledNode(type="W", chunk=0, stage=0, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
        ],
        # stage 1
        [
+            # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
@@ -94,9 +153,67 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
+            # microbatch 1
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1),
+            ScheduledNode(type="F", chunk=0, stage=1, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1),
+            ScheduledNode(type="F", chunk=1, stage=1, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1),
+            ScheduledNode(type="B", chunk=1, stage=1, minibatch=1),
+            ScheduledNode(type="W", chunk=1, stage=1, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1),
+            ScheduledNode(type="B", chunk=0, stage=1, minibatch=1),
+            ScheduledNode(type="W", chunk=0, stage=1, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
+            # microbatch 2
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2),
+            ScheduledNode(type="F", chunk=0, stage=1, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2),
+            ScheduledNode(type="F", chunk=1, stage=1, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2),
+            ScheduledNode(type="B", chunk=1, stage=1, minibatch=2),
+            ScheduledNode(type="W", chunk=1, stage=1, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2),
+            ScheduledNode(type="B", chunk=0, stage=1, minibatch=2),
+            ScheduledNode(type="W", chunk=0, stage=1, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
+            # microbatch 3
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3),
+            ScheduledNode(type="F", chunk=0, stage=1, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3),
+            ScheduledNode(type="F", chunk=1, stage=1, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3),
+            ScheduledNode(type="B", chunk=1, stage=1, minibatch=3),
+            ScheduledNode(type="W", chunk=1, stage=1, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3),
+            ScheduledNode(type="B", chunk=0, stage=1, minibatch=3),
+            ScheduledNode(type="W", chunk=0, stage=1, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
        ],
        # stage 2
        [
+            # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
@@ -114,10 +231,68 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
-            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), # Send nothing
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),
+            # microbatch 1
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1),
+            ScheduledNode(type="F", chunk=0, stage=2, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1),
+            ScheduledNode(type="F", chunk=1, stage=2, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1),
+            ScheduledNode(type="B", chunk=1, stage=2, minibatch=1),
+            ScheduledNode(type="W", chunk=1, stage=2, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1),
+            ScheduledNode(type="B", chunk=0, stage=2, minibatch=1),
+            ScheduledNode(type="W", chunk=0, stage=2, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1),
+            # microbatch 2
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2),
+            ScheduledNode(type="F", chunk=0, stage=2, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2),
+            ScheduledNode(type="F", chunk=1, stage=2, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2),
+            ScheduledNode(type="B", chunk=1, stage=2, minibatch=2),
+            ScheduledNode(type="W", chunk=1, stage=2, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2),
+            ScheduledNode(type="B", chunk=0, stage=2, minibatch=2),
+            ScheduledNode(type="W", chunk=0, stage=2, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2),
+            # microbatch 3
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3),
+            ScheduledNode(type="F", chunk=0, stage=2, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3),
+            ScheduledNode(type="F", chunk=1, stage=2, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3),
+            ScheduledNode(type="B", chunk=1, stage=2, minibatch=3),
+            ScheduledNode(type="W", chunk=1, stage=2, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3),
+            ScheduledNode(type="B", chunk=0, stage=2, minibatch=3),
+            ScheduledNode(type="W", chunk=0, stage=2, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3),
        ],
        # stage 3
        [
+            # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
@@ -136,6 +311,63 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
+            # microbatch 1
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1),
+            ScheduledNode(type="F", chunk=0, stage=3, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1),
+            ScheduledNode(type="F", chunk=1, stage=3, minibatch=1),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1),
+            ScheduledNode(type="B", chunk=1, stage=3, minibatch=1),
+            ScheduledNode(type="W", chunk=1, stage=3, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1),
+            ScheduledNode(type="B", chunk=0, stage=3, minibatch=1),
+            ScheduledNode(type="W", chunk=0, stage=3, minibatch=1),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1),
+            # microbatch 2
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2),
+            ScheduledNode(type="F", chunk=0, stage=3, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2),
+            ScheduledNode(type="F", chunk=1, stage=3, minibatch=2),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2),
+            ScheduledNode(type="B", chunk=1, stage=3, minibatch=2),
+            ScheduledNode(type="W", chunk=1, stage=3, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2),
+            ScheduledNode(type="B", chunk=0, stage=3, minibatch=2),
+            ScheduledNode(type="W", chunk=0, stage=3, minibatch=2),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2),
+            # microbatch 3
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3),
+            ScheduledNode(type="F", chunk=0, stage=3, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3),
+            ScheduledNode(type="F", chunk=1, stage=3, minibatch=3),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3),
+            ScheduledNode(type="B", chunk=1, stage=3, minibatch=3),
+            ScheduledNode(type="W", chunk=1, stage=3, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3),
+            ScheduledNode(type="B", chunk=0, stage=3, minibatch=3),
+            ScheduledNode(type="W", chunk=0, stage=3, minibatch=3),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3),
        ],
    ]

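The four per-stage lists above repeat the same 19-node block once per microbatch. A hypothetical generator for that repeating pattern (`make_stage_schedule` is not part of the ColossalAI API, and the hand-written lists also carry per-stage quirks, e.g. stage 1's chunk-0 `SEND_BACKWARD` nodes use `stage=0`, which this sketch does not reproduce):

# Illustrative generator for one stage's repeating V-schedule block.
def make_stage_schedule(stage: int, num_microbatch: int) -> list:
    nodes = []
    for mb in range(num_microbatch):
        for chunk in (0, 1):  # forward through chunk 0, then chunk 1
            nodes += [
                ScheduledNode(type="RECV_FORWARD", chunk=chunk, stage=stage, minibatch=mb),
                ScheduledNode(type="F", chunk=chunk, stage=stage, minibatch=mb),
                ScheduledNode(type="SEND_FORWARD", chunk=chunk, stage=stage, minibatch=mb),
            ]
        for chunk in (1, 0):  # backward through chunk 1, then chunk 0
            nodes += [
                ScheduledNode(type="RECV_BACKWARD", chunk=chunk, stage=stage, minibatch=mb),
                ScheduledNode(type="B", chunk=chunk, stage=stage, minibatch=mb),
                ScheduledNode(type="W", chunk=chunk, stage=stage, minibatch=mb),
                ScheduledNode(type="SEND_BACKWARD", chunk=chunk, stage=stage, minibatch=mb),
            ]
    return nodes

# Would rebuild the bulk of the hand-written schedule for 4 stages x 4 microbatches:
zbv_schedule_generated = [make_stage_schedule(stage, 4) for stage in range(4)]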
@@ -143,7 +375,7 @@ def test_run_fwd_bwd_base(
        schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=pp_size,
-        num_microbatch=1,
+        num_microbatch=num_microbatch,
        overlap_p2p=False,
    )

@@ -152,14 +384,15 @@ def test_run_fwd_bwd_base(
        return (x * x).mean()

    # init model and input
+    batch_size = 4
    num_layers = 8
    in_dim = out_dim = 8
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
-    # data_iter = [input0]
+    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]

-    input_base = input0.clone()
+    [t.clone() for t in data_iter]
    model_base = deepcopy(model)

    if rank == 0:
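Note that the new `[t.clone() for t in data_iter]` builds a list of clones and immediately discards it, so the baseline pass later in this diff reads the same tensors the scheduler consumes. If an independent baseline copy were intended, the result would need to be kept (variable name illustrative):

# Keep the clones so the baseline runs on an independent input copy.
data_iter_base = [t.clone() for t in data_iter]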
@@ -193,9 +426,7 @@ def test_run_fwd_bwd_base(
    torch.cuda.synchronize()
    scheduler.run_forward_backward(
        model_chunk=local_chunk,
-        input_obj=input0,
-        # data_iter=iter(data_iter),
-        data_iter=None,
+        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=None,
        return_loss=None,
@@ -206,8 +437,7 @@ def test_run_fwd_bwd_base(
    # Fwd bwd for base
    ##########################
    # fwd & bwd
-    output_base = model_base(input_base)
-    # loss_base = output_base.mean()
+    output_base = model_base(data_iter[0])
    loss_base = criterion(output_base)
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -245,15 +475,6 @@ def test_run_fwd_bwd_base(
    assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)


-# Test iter input & multiple microbatch
-def test_run_fwd_bwd_iter_input(
-    rank: int,
-    world_size: int,
-    port: int,
-):
-    pass
-
-
 @pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
 # @pytest.mark.parametrize("batch_size", [4])
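The `assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)` context line above is one instance of the p & p.grad checks the commit message refers to. For 8 layers folded V-style onto 4 stages, rank `r` appears to hold layer `r` as chunk 0 and layer `7 - r` as chunk 1 (reconstructed from the rank-3 check shown; the test's actual per-rank branches may differ):

# Reconstructed per-rank weight/grad comparison pattern (illustrative).
assert_close(local_chunk[0].weight, model_base.layers[rank].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[rank].weight.grad)
assert_close(local_chunk[1].weight, model_base.layers[7 - rank].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[7 - rank].weight.grad)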
@@ -261,7 +482,7 @@ def test_run_fwd_bwd_iter_input(
 @rerun_if_address_is_in_use()
 def test_pp():
     spawn(
-        test_run_fwd_bwd_base,
+        test_run_fwd_bwd_iter_input,
        nprocs=4,
    )

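`spawn` launches `nprocs` worker processes and passes each its rank, so `test_run_fwd_bwd_iter_input` runs once per pipeline stage. Such ColossalAI test files typically close with a direct entry point (assumed here, not shown in this diff):

# Assumed entry point for running the test file directly.
if __name__ == "__main__":
    test_pp()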