mirror of https://github.com/hpcaitech/ColossalAI.git

[feat] add dw test;

commit c18ef060cf
parent ee9baedadf
@@ -64,8 +64,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def _free_buffers(self):
         # free local buffer
         # two dim array, first dim is the model chunk, second dim is the microbatch queue
+
+        # x & y buffer for schedule b
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
+
+        # y & dy buffer for schedule b
+        self.output_tensors_dw = [[], []]
+        self.output_tensors_grad_dw = [[], []]
+
         self.send_forward_buffer = [[], []]
         self.recv_forward_buffer = [[], []]
         self.send_backward_buffer = [[], []]
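Why the two extra queues: the zero-bubble schedule splits backward into a B step (input gradients, needed immediately by the previous stage) and a deferred W step (weight gradients, used to fill pipeline bubbles), so each chunk must keep y and its incoming gradient dy alive between the two steps. A minimal, self-contained sketch of that split (not ColossalAI code; the Linear layer and shapes are purely illustrative):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
y = layer(x)
dy = torch.ones_like(y)  # stands in for the gradient received from the next stage

# B step: only the input gradient; keep the graph alive for the deferred W step
dx = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)[0]

# W step: weight gradients, possibly run much later in the schedule
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))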
@@ -467,7 +474,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
+        # input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -479,8 +486,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters()))
-
+                torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
             else:
                 torch.autograd.backward(
                     tensors=output_obj,
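The change to this torch.autograd.backward call is a genuine bug fix, not a cleanup: list() accepts no keyword arguments, so the old list(model=...) form raised a TypeError before autograd ever ran. A standalone illustration with a stand-in chunk list:

import torch.nn as nn

model_chunk = [nn.Linear(2, 2)]  # stand-in for the scheduler's chunk list
params = list(model_chunk[0].parameters())  # fixed form: positional iterable
# list(model=model_chunk[0].parameters())   # old form raises:
#   TypeError: list() takes no keyword arguments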
@@ -518,10 +524,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")

-        # add input and output object for backward
+        # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)

+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(output_obj)
+
         # Step3: send fwd
         send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)

@@ -544,10 +553,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
         # print(f"recv output_tensor_grad {output_tensor_grad}")

-        # get input and output object from buffer
+        # get input and output object from buffer;
         input_obj = self.input_tensors[model_chunk_id].pop()
         output_obj = self.output_tensors[model_chunk_id].pop()

+        # save output_tensor_grad for dw
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # we save loss here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
+        else:
+            # we save output_tensor_grad here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
+
         _wait_p2p(recv_bwd_handles)
         # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
         # Step2: bwd step
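The new branch separates the loss-holding stage from every other stage, as the inline comments say: where the loss lives there is no upstream gradient, so the scalar output itself is queued and backward needs no grad_tensors; everywhere else the received output_tensor_grad is queued for the W step. A self-contained illustration of the two W-step cases (toy layer, not the pipeline code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

# loss stage: a scalar loss needs no incoming gradient
loss = layer(x).mean()
torch.autograd.backward(loss, inputs=list(layer.parameters()))

# any other stage: use the gradient received from the downstream stage
y = layer(x)
dy = torch.ones_like(y)
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))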
@@ -571,15 +588,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
-        output_obj: Union[dict, torch.Tensor],
-        output_obj_grad: Optional[dict],
     ):
+
+        # get y & dy from buffer
+        output_obj = self.output_tensors_dw[model_chunk_id].pop()
+        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
+
         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             # optimizer: OptimizerWrapper,
-            input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
@@ -4,6 +4,7 @@ from typing import Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -56,13 +57,13 @@ def test_zerobubble_pipeline_base(

     # init model and input
     num_layers = 8
-    in_dim = out_dim = 2048
+    in_dim = out_dim = 8
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)

-    input0.clone()
-    deepcopy(model)
+    input_base = input0.clone()
+    model_base = deepcopy(model)

     if rank == 0:
         # layer 0 & 7 to chunk 0 on rank0
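Shrinking in_dim/out_dim from 2048 to 8 keeps the element-wise comparisons added below cheap. Those asserts index model_base.layers[i] and compare only .weight, which pins down the shape MlpModel must have: an indexable stack of equal-width Linear layers. A hypothetical stand-in consistent with those asserts (the real definition lives elsewhere in the test file and may differ, e.g. in bias handling):

import torch.nn as nn

class MlpModel(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, num_layers: int):
        super().__init__()
        # indexable stack of equal-width layers, as the asserts require
        self.layers = nn.ModuleList(
            nn.Linear(in_dim, out_dim, bias=False) for _ in range(num_layers)
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x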
@@ -245,6 +246,13 @@ def test_zerobubble_pipeline_base(
             model_chunk_id=chunk_id,
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_0,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # # chunk 1 id 1 (layer 6) bwd
     if rank == 1:
@@ -255,6 +263,13 @@ def test_zerobubble_pipeline_base(
             model_chunk_id=chunk_id,
             # optimizer: OptimizerWrapper,
         )
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_1,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 2 id 1 (layer 5) bwd
     if rank == 2:
@@ -266,6 +281,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_2,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 3 id 1 (layer 4) bwd
     if rank == 3:
         chunk_id = 1
@@ -276,6 +299,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_3,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # ######
     # # bwd rank 1->4
     # ######
@@ -290,6 +321,13 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
         # print(f"input_grad3 {input_grad3}")
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_3,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 2 id 0 (layer 2) bwd
     if rank == 2:
@@ -301,6 +339,13 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
         # print(f"input_grad2 {input_grad2}")
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_2,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 1 id 0 (layer 1) bwd
     if rank == 1:
@@ -312,6 +357,14 @@ def test_zerobubble_pipeline_base(
             # optimizer: OptimizerWrapper,
         )
+
+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_1,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )

     # chunk 0 id 0 (layer 0) bwd
     if rank == 0:
         chunk_id = 0
@@ -323,6 +376,55 @@ def test_zerobubble_pipeline_base(
         )
         # print(f"input_grad0 {input_grad0}")

+        scheduler.schedule_w(
+            scheduled_node=None,
+            non_w_pending=None,
+            model_chunk=chunk_0,
+            model_chunk_id=chunk_id,
+            # optimizer: OptimizerWrapper,
+        )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base)
+    loss_base = output_base.mean()
+    loss_base.backward()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    # assert weight
+    if rank == 0:
+        # layer 0
+        assert_close(chunk_0[0].weight, model_base.layers[0].weight)
+        assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(chunk_0[1].weight, model_base.layers[7].weight)
+        assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(chunk_1[0].weight, model_base.layers[1].weight)
+        assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(chunk_1[1].weight, model_base.layers[6].weight)
+        assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad)
+
+    if rank == 2:
+        # layer 2
+        assert_close(chunk_2[0].weight, model_base.layers[2].weight)
+        assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(chunk_2[1].weight, model_base.layers[5].weight)
+        assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad)
+
+    if rank == 3:
+        # layer 3
+        assert_close(chunk_3[0].weight, model_base.layers[3].weight)
+        assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(chunk_3[1].weight, model_base.layers[4].weight)
+        assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)
+

 # @pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
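The closing block is the dw test itself: after each rank has run its scheduled B and W steps, every chunk's weights and, more importantly, weight gradients must match a single-process baseline. The weight asserts pass trivially since no optimizer step runs (the optimizer argument is commented out throughout); the .grad asserts are the substantive check on the deferred W computation. assert_close raises an AssertionError with a detailed mismatch report on failure; a quick standalone demonstration of its default tolerance behavior:

import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0])
assert_close(a, a.clone())    # passes: equal within default rtol/atol
# assert_close(a, a + 1e-2)   # would raise AssertionError with a mismatch report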