mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-05 15:44:49 +00:00
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests * fix: remove send_forward_recv_forward as p2p op list need to use the same group * fix: make send and receive atomic * feat: update P2PComm fn * feat: add metadata cache in 1f1b * feat: add metadata cache in interleaved pp * feat: modify is_xx_stage fn * revert: add _broadcast_object_list * feat: add interleaved pp in llama policy * feat: set NCCL_BUFFSIZE in HybridParallelPlugin
This commit is contained in:
@@ -37,12 +37,13 @@ def pp_linear_fwd(
|
||||
stage_mgr: PipelineStageManager = None,
|
||||
model_chunk_id: int = None,
|
||||
):
|
||||
if stage_mgr.is_first_stage(model_chunk_id):
|
||||
return {"input_obj": forward(data)}
|
||||
elif stage_mgr.is_last_stage(model_chunk_id):
|
||||
return forward(input_obj)
|
||||
else:
|
||||
return {"input_obj": forward(input_obj)}
|
||||
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
||||
if stage_mgr.is_first_stage():
|
||||
return {"input_obj": forward(data)}
|
||||
elif stage_mgr.is_last_stage():
|
||||
return forward(input_obj)
|
||||
else:
|
||||
return {"input_obj": forward(input_obj)}
|
||||
|
||||
|
||||
def run_pp(
|
||||
@@ -107,7 +108,7 @@ def run_pp(
|
||||
)
|
||||
|
||||
# check loss
|
||||
if stage_manager.is_last_stage(-1):
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
# check gradients
|
||||
@@ -119,6 +120,7 @@ def run_pp(
|
||||
# step
|
||||
torch_optimizer.step()
|
||||
pp_optimizer.step()
|
||||
pp_optimizer.zero_grad()
|
||||
|
||||
# check updated param
|
||||
for i in range(num_model_chunk):
|
||||
@@ -126,6 +128,24 @@ def run_pp(
|
||||
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
||||
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
||||
|
||||
# forward only
|
||||
with torch.no_grad():
|
||||
torch_output = torch_model(input_list[0])
|
||||
torch_loss = criterion(torch_output)
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
)
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
for layer in sharded_model:
|
||||
if layer.weight.grad is None:
|
||||
assert layer.weight.grad is None and layer.bias.grad is None
|
||||
else:
|
||||
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
|
||||
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
||||
|
||||
Reference in New Issue
Block a user