[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)

* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------

Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
This commit is contained in:
Wenhao Chen
2023-11-28 16:54:42 +08:00
committed by GitHub
parent 126cf180bc
commit 7172459e74
39 changed files with 4007 additions and 234 deletions

View File

@@ -4,6 +4,7 @@ from types import MethodType
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
@@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
NUM_LAYER = 8
DIM = 4
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(4, 8)
self.linear2 = nn.Linear(8, 8)
self.linear3 = nn.Linear(8, 8)
self.linear4 = nn.Linear(8, 8)
self.linear5 = nn.Linear(8, 8)
self.linear6 = nn.Linear(8, 8)
self.linear7 = nn.Linear(8, 8)
self.linear8 = nn.Linear(8, 4)
super().__init__()
self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.linear5(x)
x = self.linear6(x)
x = self.linear7(x)
x = self.linear8(x)
for layer in self.layers:
x = layer(x)
return x
@@ -44,70 +35,71 @@ def pp_linear_fwd(
data: torch.Tensor = None,
input_obj: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
num_chunks: int = None,
model_chunk_id: int = None,
):
if stage_mgr.is_first_stage() and model_chunk_id == 0:
if stage_mgr.is_first_stage(model_chunk_id):
return {"input_obj": forward(data)}
elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
elif stage_mgr.is_last_stage(model_chunk_id):
return forward(input_obj)
else:
return {"input_obj": forward(input_obj)}
@parameterize("num_micro_batches", [4, 8, 12])
def examine_pp(num_micro_batches):
def run_pp(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
num_model_chunk: int,
):
"""
This test is to examine the correctness of interleaved 1F1B, compared with torch.
Be aware it contains some hardcodes.
"""
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
seed_all(1453)
NUM_MICRO_BATCHS = num_micro_batches
BATCH_SIZE = num_micro_batches
NUM_CHUNKS = 2
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
# create model
seed_all(1453)
torch_model = MlpModel().cuda()
pp_model = copy.deepcopy(torch_model).cuda()
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
pg_mesh = ProcessGroupMesh(world_size)
stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
)
schedule = InterleavedSchedule(
stage_manager=stage_manager,
num_model_chunks=num_model_chunk,
num_microbatch=num_microbatch,
)
sharded_model = torch.nn.ModuleList()
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
for idx, sub_model in enumerate(pp_model.layers):
if idx % world_size == rank:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(
pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)
),
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
sub_model._forward,
)
sharded_model.append(sub_model.cuda())
assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"
# create optimizer
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))
# create
seed_all(1453)
if local_rank == 0:
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
else:
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
torch.distributed.all_reduce(input_list[0])
# create data
seed_all(115)
input_list = [torch.rand(batch_size, DIM).cuda()]
dist.all_reduce(input_list[0])
criterion = lambda x, y: torch.mean(x)
def criterion(x, *args, **kwargs):
return (x * x).mean()
# forward and backward
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output, _)
torch_loss = criterion(torch_output)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(
@@ -115,45 +107,41 @@ def examine_pp(num_micro_batches):
)
# check loss
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(-1):
assert torch.allclose(torch_loss, pp_ret["loss"])
# check gradients
torch_grad = []
for torch_p in torch_model.parameters():
torch_grad.append(torch_p.grad.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
else:
assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
for i in range(num_model_chunk):
idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
pp_optimizer.step()
# check updated param
torch_param = []
for torch_p in torch_model.parameters():
torch_param.append(torch_p.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
else:
assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
examine_pp()
for i in range(num_model_chunk):
idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("num_model_chunk", [2, 4])
@rerun_if_address_is_in_use()
def test_pp():
spawn(run_dist, 4)
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
assert NUM_LAYER % num_model_chunk == 0
spawn(
run_pp,
nprocs=NUM_LAYER // num_model_chunk,
num_microbatch=num_microbatch,
batch_size=batch_size,
num_model_chunk=num_model_chunk,
)
if __name__ == "__main__":
test_pp()
test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)