[feat] add dw test;

duanjunwen 2024-08-23 06:04:12 +00:00
parent ee9baedadf
commit c18ef060cf
2 changed files with 132 additions and 12 deletions

View File

@@ -64,8 +64,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def _free_buffers(self):
         # free local buffer
         # two dim array, first dim is the model chunk, second dim is the microbatch queue
+        # x & y buffer for schedule b
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
+        # y & dy buffer for schedule w
+        self.output_tensors_dw = [[], []]
+        self.output_tensors_grad_dw = [[], []]
         self.send_forward_buffer = [[], []]
         self.recv_forward_buffer = [[], []]
         self.send_backward_buffer = [[], []]
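
Note: in zero-bubble scheduling the backward pass is split into a B step (input gradient, on the critical path) and a W step (weight gradient, deferrable to fill pipeline bubbles). The two new `_dw` buffers keep the output y and its incoming gradient dy alive until the W step runs. A minimal sketch of the split, with illustrative names rather than the scheduler's API:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
y = layer(x)
dy = torch.randn_like(y)  # gradient received from the next stage

# B step: input gradient only; retain the graph so W can reuse it
dx = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)[0]

# ... dx is sent upstream immediately; y and dy sit in the _dw buffers ...

# W step: weight gradients, run whenever there is a bubble to fill
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))
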
@@ -467,7 +474,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
+        # input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@ -479,8 +486,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters())) torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
else: else:
torch.autograd.backward( torch.autograd.backward(
tensors=output_obj, tensors=output_obj,
@ -518,10 +524,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}") # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
# add input and output object for backward # add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj) self.input_tensors[model_chunk_id].append(input_obj)
self.output_tensors[model_chunk_id].append(output_obj) self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
# Step3: send fwd # Step3: send fwd
send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj) send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
@ -544,10 +553,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id) output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
# print(f"recv output_tensor_grad {output_tensor_grad}") # print(f"recv output_tensor_grad {output_tensor_grad}")
# get input and output object from buffer # get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop() input_obj = self.input_tensors[model_chunk_id].pop()
output_obj = self.output_tensors[model_chunk_id].pop() output_obj = self.output_tensors[model_chunk_id].pop()
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else:
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
_wait_p2p(recv_bwd_handles) _wait_p2p(recv_bwd_handles)
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}") # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
# Step2: bwd step # Step2: bwd step
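
The branch that appends output_obj rather than output_tensor_grad covers the loss-owning position: in the V schedule, chunk 1 on the first stage is the tail of the pipeline, so no upstream gradient arrives over p2p and backward must be seeded from the loss itself. The two seeding modes, in miniature (a sketch, not the scheduler's code):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(4, 8)
y = layer(x)

# loss-owning stage: seed backward from the scalar loss itself
loss = y.mean()
torch.autograd.backward(loss, inputs=list(layer.parameters()), retain_graph=True)

# any other stage: seed backward from the gradient dy received over p2p
dy = torch.randn_like(y)
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))
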
@@ -571,15 +588,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
-        output_obj: Union[dict, torch.Tensor],
-        output_obj_grad: Optional[dict],
     ):
+        # get y & dy from buffer
+        output_obj = self.output_tensors_dw[model_chunk_id].pop()
+        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()

         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             # optimizer: OptimizerWrapper,
-            input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
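
Taken together: schedule_f pushes y into output_tensors_dw, schedule_b pushes dy (or the loss) into output_tensors_grad_dw, and schedule_w pops the pair and runs the weight-only backward. A condensed, runnable mock of that hand-off (a hypothetical TinyZB class; the real scheduler adds p2p communication and per-chunk queues):

import torch
import torch.nn as nn

class TinyZB:
    def __init__(self, layer: nn.Module):
        self.layer = layer
        self.buf_y, self.buf_dy = [], []  # mirrors output_tensors_dw / output_tensors_grad_dw

    def schedule_f(self, x):
        y = self.layer(x)
        self.buf_y.append(y)  # y saved for the W step
        return y

    def schedule_b(self, x, y, dy):
        self.buf_dy.append(dy)  # dy saved for the W step
        # input gradient only; keep the graph alive for W
        return torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)[0]

    def schedule_w(self):
        y, dy = self.buf_y.pop(), self.buf_dy.pop()
        torch.autograd.backward(y, grad_tensors=dy, inputs=list(self.layer.parameters()))

zb = TinyZB(nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)
y = zb.schedule_f(x)
dx = zb.schedule_b(x, y, torch.ones_like(y))
zb.schedule_w()  # zb.layer.weight.grad is now populated
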

View File

@@ -4,6 +4,7 @@ from typing import Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -56,13 +57,13 @@ def test_zerobubble_pipeline_base(
     # init model and input
     num_layers = 8
-    in_dim = out_dim = 2048
+    in_dim = out_dim = 8
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
-    input0.clone()
-    deepcopy(model)
+    input_base = input0.clone()
+    model_base = deepcopy(model)

     if rank == 0:
         # layer 0 & 7 to chunk 0 on rank0
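
The renamed input_base/model_base now actually bind the reference copies (previously the clone and deepcopy results were discarded), and shrinking in_dim/out_dim from 2048 to 8 keeps this correctness test cheap. The single-process reference pattern the test relies on, in miniature (illustrative, mirroring the assert_close checks added further down):

import torch
import torch.nn as nn
from copy import deepcopy
from torch.testing import assert_close

torch.manual_seed(0)
layer = nn.Linear(8, 8)
layer_base = deepcopy(layer)
x = torch.randn(4, 8, requires_grad=True)
x_base = x.detach().clone().requires_grad_()

# pipeline-style split backward: B step then W step
y = layer(x)
dy = torch.ones_like(y) / y.numel()  # gradient of mean()
dx = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)[0]
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))

# reference: one ordinary fwd & bwd
layer_base(x_base).mean().backward()

assert_close(layer.weight.grad, layer_base.weight.grad)
assert_close(dx, x_base.grad)
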
@@ -245,6 +246,13 @@ def test_zerobubble_pipeline_base(
             model_chunk_id=chunk_id,
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_0,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # # chunk 1 id 1 (layer 6) bwd
     if rank == 1:
@@ -255,6 +263,13 @@ def test_zerobubble_pipeline_base(
             model_chunk_id=chunk_id,
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_1,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 2 id 1 (layer 5) bwd
     if rank == 2:
@@ -266,6 +281,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_2,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 3 id 1 (layer 4) bwd
     if rank == 3:
         chunk_id = 1
@@ -276,6 +299,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_3,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # ######
     # # bwd rank 1->4
     # ######
@@ -290,6 +321,13 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
         # print(f"input_grad3 {input_grad3}")
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_3,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 2 id 0 (layer 2) bwd
     if rank == 2:
@@ -301,6 +339,13 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
         # print(f"input_grad2 {input_grad2}")
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_2,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 1 id 0 (layer 1) bwd
     if rank == 1:
@@ -312,6 +357,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_1,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 0 id 0 (layer 0) bwd
     if rank == 0:
         chunk_id = 0
@@ -323,6 +376,55 @@ def test_zerobubble_pipeline_base(
         )
         # print(f"input_grad0 {input_grad0}")
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_0,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base)
+    loss_base = output_base.mean()
+    loss_base.backward()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    # assert weight
+    if rank == 0:
+        # layer 0
+        assert_close(chunk_0[0].weight, model_base.layers[0].weight)
+        assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(chunk_0[1].weight, model_base.layers[7].weight)
+        assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(chunk_1[0].weight, model_base.layers[1].weight)
+        assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(chunk_1[1].weight, model_base.layers[6].weight)
+        assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad)
+    if rank == 2:
+        # layer 2
+        assert_close(chunk_2[0].weight, model_base.layers[2].weight)
+        assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(chunk_2[1].weight, model_base.layers[5].weight)
+        assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad)
+    if rank == 3:
+        # layer 3
+        assert_close(chunk_3[0].weight, model_base.layers[3].weight)
+        assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(chunk_3[1].weight, model_base.layers[4].weight)
+        assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)


 # @pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
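
For reference, tests like this in the ColossalAI suite are normally launched across processes with the project's testing helpers. A typical harness would look roughly like the sketch below; it assumes colossalai.testing's spawn/rerun_if_address_is_in_use helpers and the current colossalai.launch signature, and run_dist is a hypothetical worker wrapping the test body above:

import pytest
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn

def run_dist(rank, world_size, port):
    # hypothetical worker: init the process group, then run the fwd/bwd body above
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # test_zerobubble_pipeline_base(...) would be called here

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zerobubble():
    # four ranks, one pipeline stage (two model chunks) each
    spawn(run_dist, nprocs=4)
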