ColossalAI/tests/test_pipeline/test_schedule/test_interleaved.py

import copy
from functools import partial
from types import MethodType

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all

NUM_LAYER = 8
DIM = 4


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
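

# The schedule drives each sub-module through this patched forward: the first
# stage consumes the raw microbatch, the last stage returns the final output
# for the loss, and intermediate stages relay the received activation under
# the "input_obj" key, matching the `input_obj` parameter of the next stage.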
def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    input_obj: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
    model_chunk_id: int = None,
):
    with stage_mgr.switch_model_chunk_id(model_chunk_id):
        if stage_mgr.is_first_stage():
            return {"input_obj": forward(data)}
        elif stage_mgr.is_last_stage():
            return forward(input_obj)
        else:
            return {"input_obj": forward(input_obj)}


def run_pp(
    rank: int,
    world_size: int,
    port: int,
    num_microbatch: int,
    batch_size: int,
    num_model_chunk: int,
):
    """
    Check the correctness of interleaved 1F1B against a plain torch run.
    Be aware that it contains some hardcoded values.
    """
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    # create model
    seed_all(1453)
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )
    schedule = InterleavedSchedule(
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
    )
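
    # Interleaved placement: chunk i on this rank wraps global layer
    # i * world_size + rank, so every rank ends up holding num_model_chunk layers.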
    sharded_model = torch.nn.ModuleList()
    for idx, sub_model in enumerate(pp_model.layers):
        if idx % world_size == rank:
            sub_model._forward = sub_model.forward
            sub_model.forward = MethodType(
                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                sub_model._forward,
            )
            sharded_model.append(sub_model.cuda())
    assert len(sharded_model) == num_model_chunk, "unexpected number of model chunks on this rank"

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

    # create data
    seed_all(115)
    input_list = [torch.rand(batch_size, DIM).cuda()]
    # all-reduce so that every rank feeds the identical input tensor
    dist.all_reduce(input_list[0])

    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output)
    torch_loss.backward()
    pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)

    # check loss
    if stage_manager.is_last_stage(ignore_chunk=True):
        assert_close(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
    pp_optimizer.zero_grad()

    # check updated param
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
        torch_output = torch_model(input_list[0])
        torch_loss = criterion(torch_output)

        pp_ret = schedule.forward_backward_step(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
        )
        if stage_manager.is_last_stage(ignore_chunk=True):
            assert_close(torch_loss, pp_ret["loss"])

        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.weight.grad is None and layer.bias.grad is None
            else:
                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
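

# Each process hosts one pipeline stage, so the spawned world size is
# NUM_LAYER // num_model_chunk and every rank holds num_model_chunk layers.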
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("num_model_chunk", [2, 4])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
    assert NUM_LAYER % num_model_chunk == 0
    spawn(
        run_pp,
        nprocs=NUM_LAYER // num_model_chunk,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
        num_model_chunk=num_model_chunk,
    )


if __name__ == "__main__":
    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
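    # Assumption: running this file directly spawns NUM_LAYER // 4 = 2 local
    # processes, so at least two visible CUDA devices are needed.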