[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as a p2p op list needs to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
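
To make the cache and buffer-size bullets above concrete: the metadata cache means a pipeline stage exchanges a tensor's shape/dtype header with its peer only for the first microbatch and reuses the cached header afterwards, and NCCL_BUFFSIZE is raised through an environment variable before NCCL initializes. The sketch below is an illustration of that caching idea only, assuming an already-initialized torch.distributed process group; the names (CachedP2P, _MAX_NDIM, the dtype tables) and the 128 MB buffer value are invented here and are not ColossalAI's actual P2PComm API or HybridParallelPlugin code.

import os

import torch
import torch.distributed as dist

# Hypothetical value: a larger NCCL buffer can help when many p2p ops are batched.
# Must be set before NCCL is initialized (i.e. before the first collective/p2p call).
os.environ.setdefault("NCCL_BUFFSIZE", str(128 * 1024 * 1024))

_MAX_NDIM = 8  # fixed-size header so the metadata fits in one small tensor
_DTYPE_TO_CODE = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}  # kept small for brevity
_CODE_TO_DTYPE = {v: k for k, v in _DTYPE_TO_CODE.items()}


class CachedP2P:
    """Send/recv activations between adjacent pipeline stages.

    The shape/dtype header is exchanged only for the first microbatch and then cached,
    so every later microbatch is a single tensor transfer with no metadata round-trip.
    """

    def __init__(self, peer: int):
        self.peer = peer
        self.header_sent = False   # sender side: has the peer already seen the metadata?
        self.cached_header = None  # receiver side: (shape, dtype) once received

    def send(self, tensor: torch.Tensor) -> None:
        if not self.header_sent:
            # First microbatch: encode dtype code, ndim and shape into one int64 tensor.
            header = torch.full((_MAX_NDIM + 2,), -1, dtype=torch.int64, device=tensor.device)
            header[0] = _DTYPE_TO_CODE[tensor.dtype]
            header[1] = tensor.dim()
            header[2 : 2 + tensor.dim()] = torch.tensor(
                tensor.shape, dtype=torch.int64, device=tensor.device
            )
            dist.send(header, dst=self.peer)
            self.header_sent = True  # metadata is now cached on the receiver
        dist.send(tensor.contiguous(), dst=self.peer)

    def recv(self, device: torch.device) -> torch.Tensor:
        if self.cached_header is None:
            # First microbatch: receive and decode the header, then cache it.
            header = torch.empty(_MAX_NDIM + 2, dtype=torch.int64, device=device)
            dist.recv(header, src=self.peer)
            dtype = _CODE_TO_DTYPE[int(header[0])]
            ndim = int(header[1])
            self.cached_header = (tuple(int(s) for s in header[2 : 2 + ndim]), dtype)
        shape, dtype = self.cached_header
        buf = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(buf, src=self.peer)
        return buf

In this sketch a sender stage would call p2p.send(activation) once per microbatch while the receiving stage calls p2p.recv(device); only the first microbatch pays for the extra metadata message.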
Author: Wenhao Chen
Date: 2023-12-22 10:44:00 +08:00
Committer: GitHub
Parent: af952673f7
Commit: 4fa689fca1

15 changed files with 728 additions and 446 deletions


@@ -4,6 +4,7 @@ from types import MethodType

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai
@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+DIM = 8
+NUM_LAYER = 8
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        for layer in self.layers:
+            x = layer(x)
         return x


 def pp_linear_fwd(
-    forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
+    forward,
+    data: torch.Tensor = None,
+    input_obj: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
 ):
     if stage_mgr.is_first_stage():
         return {"input_obj": forward(data)}
@@ -38,34 +44,45 @@ def pp_linear_fwd
         return {"input_obj": forward(input_obj)}


-def examine_pp():
+def examine_pp(num_microbatch: int, batch_size: int):
     """
     This test is to examine the correctness of 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    dist.get_rank()
     seed_all(1453)
-    NUM_MICRO_BATCHS = 4
-    BATCH_SIZE = 4

     # create models
     torch_model = MlpModel().cuda()
     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)

-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
-            sharded_model = sub_model.cuda()
+    rank = dist.get_rank()
+    sharded_model = torch.nn.ModuleList()
+    num_local_layer = NUM_LAYER // world_size
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx // num_local_layer == rank:
+            sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_local_layer

-    sharded_model._forward = sharded_model.forward
-    sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
+    def custom_fwd(self, x):
+        for layer in self._modules.values():
+            x = layer(x)
+        return x
+
+    sharded_model._forward = MethodType(custom_fwd, sharded_model)
+    sharded_model.forward = MethodType(
+        partial(
+            pp_linear_fwd,
+            stage_mgr=stage_manager,
+        ),
+        sharded_model._forward,
+    )

     # create optimizer
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -73,19 +90,15 @@ def examine_pp():
     # create
     seed_all(1453)
-
-    if stage_manager.is_first_stage():
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    criterion = lambda x, *arg, **kwargs: (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(
         sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
     )
@@ -95,34 +108,66 @@ def examine_pp():
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage():
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+

-def run_dist(rank, world_size, port):
+def run_dist(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    examine_pp(num_microbatch, batch_size)


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 2)
+def test_pp(num_microbatch: int, batch_size: int, world_size: int):
+    assert NUM_LAYER % world_size == 0
+    spawn(
+        run_dist,
+        world_size,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, world_size=4)