Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 04:02:17 +00:00)
[feat] add optim backward_b_by_grad

This commit is contained in:
parent: b1419ef76a
commit: 4c4b01b859
@@ -58,6 +58,28 @@ class OptimizerWrapper:

     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)

+    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+        """
+        Performs a backward pass for dx; only dx = w * dy is calculated here.
+
+        Args:
+            tensor (Tensor): y or loss of the current chunk;
+            grad_tensors (Tensor): dy of the current chunk;
+            inputs (Tensor): x of the current chunk;
+            retain_graph (bool): defaults to True; the graph is retained so backward_w can run afterwards.
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad_tensors,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
+
+    def backward_w_by_grad(self):
+        """
+        Performs a backward pass for dw; only dw = x * dy is calculated here.
+        """
+
     def state_dict(self):
         """
         Returns the optimizer state.
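For orientation, here is a minimal standalone sketch (not part of the commit) of the split the two new methods implement: `inputs=` restricts which leaves receive gradients, and `retain_graph=True` keeps the autograd graph alive so the dw pass can run after the dx pass.

import torch
from torch import nn

x = torch.randn(4, 8, requires_grad=True)  # activation from the previous stage
w = nn.Parameter(torch.randn(8, 8))        # this chunk's weight
y = x @ w                                  # forward
dy = torch.ones_like(y)                    # upstream gradient

# B step (dx only): accumulate into x alone, keep the graph for the W step.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert torch.allclose(x.grad, dy @ w.detach().T)

# W step (dw only): reuse the retained graph, now targeting the weight.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w])
assert torch.allclose(w.grad, x.detach().T @ dy)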
@@ -413,7 +413,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -447,7 +447,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
         else:
             # common bwd step
-            # BUG: output_obj_grad is None
             torch.autograd.backward(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
             )
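The two branches above distinguish the last pipeline stage, which holds the scalar loss, from every other stage, which only has an activation plus the dy received over p2p. A hedged sketch of the same distinction in isolation (illustrative tensors, not the scheduler's objects):

import torch

x = torch.randn(2, 3, requires_grad=True)

# Last stage: the output is the loss itself, so backward starts from an implicit 1.0.
loss = (x * x).mean()
torch.autograd.backward(loss, inputs=[x], retain_graph=True)

# Any other stage: the output is an activation, and dy arrives from the next
# stage, so it must be passed explicitly via grad_tensors.
x.grad = None
y = x * 2
dy = torch.randn_like(y)
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)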
@@ -564,7 +563,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         # input_obj: Optional[dict],
         # output_obj: Union[dict, torch.Tensor],
         # output_obj_grad: Optional[dict],
@@ -614,7 +613,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
@@ -715,6 +714,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 scheduled_node=scheduled_node,
                 model_chunk=model_chunk,
                 model_chunk_id=scheduled_node.chunk,
+                optimizer=optimizer,
             )
         elif scheduled_node.type == "W":
             self.schedule_w(
@@ -9,6 +9,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -625,7 +626,148 @@ def run_fwd_bwd_vschedule_with_optim(
     batch_size: int,
     num_model_chunk: int,
 ):
-    pass
+    # init dist
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    rank = dist.get_rank()
+    pp_size = world_size
+    pg_mesh = ProcessGroupMesh(pp_size)
+    num_microbatch = num_microbatch
+    # stage_manager
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+
+    h, a, s = 4096, 32, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=world_size,
+        n_micro=num_microbatch,
+        f_cost=6,
+        b_cost=6,
+        w_cost=6,
+        c_cost=6,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+        # max_mem=mem_f * (p * 2 + m_offset),
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    scheduler = ZeroBubbleVPipeScheduler(
+        schedule=zbv_schedule[rank],  # hint: send the whole schedule or only the local schedule?
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+        overlap_p2p=False,
+    )
+
+    # init loss func
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()
+
+    # init model and input
+    batch_size = batch_size
+    num_layers = 8
+    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layers cannot be evenly distributed over {num_model_chunk} chunks"
+    in_dim = out_dim = 8
+    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
+    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+
+    input_base = [t.clone() for t in data_iter]
+    model_base = deepcopy(model)
+
+    if rank == 0:
+        # layers 0 & 7 to chunk 0 on rank 0
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 0 or idx == 7:
+                local_chunk.append(sub_model)
+    elif rank == 1:
+        # layers 1 & 6 to chunk 1 on rank 1
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 1 or idx == 6:
+                local_chunk.append(sub_model)
+    elif rank == 2:
+        # layers 2 & 5 to chunk 2 on rank 2
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 2 or idx == 5:
+                local_chunk.append(sub_model)
+    else:
+        # layers 3 & 4 to chunk 3 on rank 3
+        local_chunk = torch.nn.Sequential().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 3 or idx == 4:
+                local_chunk.append(sub_model)
+
+    # init optimizer
+    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
+
+    print(
+        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    )
+
+    torch.cuda.synchronize()
+    scheduler.run_forward_backward(
+        model_chunk=local_chunk,
+        data_iter=iter(data_iter),
+        criterion=criterion,
+        optimizer=optimizer_pp,
+        return_loss=None,
+        return_outputs=None,
+    )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base[0])
+    loss_base = criterion(output_base)
+    loss_base.backward()
+    optimizer_base.step()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    ##########################
+    # assert weight
+    ##########################
+    if rank == 0:
+        # layer 0
+        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
+    if rank == 2:
+        # layer 2
+        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
+    if rank == 3:
+        # layer 3
+        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
+
+    ##########################
+    # assert optim state
+    ##########################
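The four rank branches above hard-code the ZB-V placement: with p pipeline ranks and 2p layers, rank r holds layer r in its first chunk and layer 2p - 1 - r in its second. A hypothetical helper (name and signature are illustrative, not from the test) that reproduces the mapping:

def v_schedule_layers(rank: int, pp_size: int) -> tuple[int, int]:
    # Chunk 0 walks down the pipeline, chunk 1 walks back up (the "V").
    return rank, 2 * pp_size - 1 - rank

# For pp_size=4 this matches the branches hard-coded in the test:
assert [v_schedule_layers(r, 4) for r in range(4)] == [(0, 7), (1, 6), (2, 5), (3, 4)]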
@@ -634,8 +776,16 @@ def run_fwd_bwd_vschedule_with_optim(
 @pytest.mark.dist
 @pytest.mark.parametrize("num_model_chunk", [4])
 @rerun_if_address_is_in_use()
 def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    # spawn(
+    #     run_fwd_bwd_with_vschedule,
+    #     nprocs=4,
+    #     num_microbatch=num_microbatch,
+    #     batch_size=batch_size,
+    #     num_model_chunk=num_model_chunk,
+    # )
+
     spawn(
-        run_fwd_bwd_with_vschedule,
+        run_fwd_bwd_vschedule_with_optim,
         nprocs=4,
         num_microbatch=num_microbatch,
         batch_size=batch_size,
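Run outside pytest, a direct driver along these lines should work, assuming 4 visible CUDA devices; the argument values here are illustrative, not taken from the (truncated) parametrize lists:

if __name__ == "__main__":
    # spawns 4 processes, one pipeline stage per device
    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)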