Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 01:06:00 +00:00
Merge branch 'main' into sync/npu
@@ -5,43 +5,69 @@ import torch.distributed as dist

 import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn

+WORLD_SIZE = 2
+

 def check_p2p_communication():
-    pg_mesh = ProcessGroupMesh(2)
+    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, 0)
     p2p = PipelineP2PCommunication(stage_manager)

     rank = dist.get_rank()

     tensor = torch.ones(1, device=get_accelerator().get_current_device())
+    data = [
+        "tensor",
+        tensor,
+        [tensor],
+        {"tensor": tensor},
+    ]

     if rank == 0:
-        p2p.send_forward(tensor)
-        p2p.send_forward([tensor])
-        p2p.send_forward({"tensor": tensor})
-    else:
-        obj = p2p.recv_forward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_forward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 1:
+        for obj in data:
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            p2p.send_backward(data[-(i + 1)])
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == data[i]

     if rank == 1:
-        p2p.send_backward(tensor)
-        p2p.send_backward([tensor])
-        p2p.send_backward({"tensor": tensor})
-    else:
-        obj = p2p.recv_backward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_backward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 0:
+        for obj in data:
+            recv_obj = p2p.recv_backward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            recv_obj = p2p.recv_backward()
+            p2p.send_forward(data[-(i + 1)])
+            assert recv_obj == data[i]
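The combined calls above (send_forward_recv_backward, send_backward_recv_forward) move objects in both directions at once; judging by the name, send_prior_fallback chooses whether the send is posted before the recv when the combined operation has to fall back to two separate calls. For readers unfamiliar with these primitives, here is a rough shape of one such exchange in raw torch.distributed terms (a sketch only; PipelineP2PCommunication additionally serializes arbitrary Python objects and resolves peer ranks from the stage manager):

import torch
import torch.distributed as dist

# Sketch: one bidirectional exchange between two ranks using batched p2p ops,
# which lets both sides post a send and a recv without deadlocking.
def ping_pong(rank: int, send_t: torch.Tensor, recv_t: torch.Tensor):
    peer = 1 - rank  # the other rank in a two-rank group
    ops = [
        dist.P2POp(dist.isend, send_t, peer),
        dist.P2POp(dist.irecv, recv_t, peer),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()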

+    if rank == 0:
+        recv_obj = p2p.send_forward_recv_backward(
+            tensor,
+            send_metadata=False,
+            metadata_recv=create_send_metadata(tensor),
+        )
+        assert recv_obj == tensor
+    elif rank == 1:
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
+        assert recv_obj == tensor
+        p2p.send_backward(tensor, send_metadata=False)


 def run_dist(rank, world_size, port):
@@ -52,7 +78,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_pipeline_p2p():
-    spawn(run_dist, 2)
+    spawn(run_dist, WORLD_SIZE)


 if __name__ == "__main__":
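The final block added above exercises the metadata fast path: when the receiver already knows the shape and type of the incoming message, both the metadata send (send_metadata=False) and the metadata receive (metadata_recv=...) can be skipped. A minimal sketch of that usage, assuming the same two-rank setup as the test (p2p, tensor and rank as defined there):

# Sketch: skip the per-message metadata exchange when the layout is known.
metadata = create_send_metadata(tensor)  # precomputed shape/dtype description
if rank == 0:
    recv_obj = p2p.send_forward_recv_backward(
        tensor, send_metadata=False, metadata_recv=metadata
    )
elif rank == 1:
    recv_obj = p2p.recv_forward(metadata_recv=metadata)
    p2p.send_backward(tensor, send_metadata=False)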
@@ -4,6 +4,7 @@ from types import MethodType

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai
@@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+NUM_LAYER = 8
+DIM = 4
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 8)
-        self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 8)
-        self.linear5 = nn.Linear(8, 8)
-        self.linear6 = nn.Linear(8, 8)
-        self.linear7 = nn.Linear(8, 8)
-        self.linear8 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        x = self.linear7(x)
-        x = self.linear8(x)
+        for layer in self.layers:
+            x = layer(x)
         return x
@@ -44,70 +35,72 @@ def pp_linear_fwd(
     data: torch.Tensor = None,
     input_obj: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
-    num_chunks: int = None,
     model_chunk_id: int = None,
 ):
-    if stage_mgr.is_first_stage() and model_chunk_id == 0:
-        return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
-        return forward(input_obj)
-    else:
-        return {"input_obj": forward(input_obj)}
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        if stage_mgr.is_first_stage():
+            return {"input_obj": forward(data)}
+        elif stage_mgr.is_last_stage():
+            return forward(input_obj)
+        else:
+            return {"input_obj": forward(input_obj)}
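With interleaving, the explicit model_chunk_id == 0 and model_chunk_id == num_chunks - 1 tests move into the stage manager: inside switch_model_chunk_id, is_first_stage() and is_last_stage() answer for the currently active chunk. For reference only (this is the logic the old inline checks encoded, not the library implementation):

# Under interleaving, "globally first" means stage 0 holding chunk 0, and
# "globally last" means the final stage holding the final chunk.
def is_global_first(stage: int, chunk: int) -> bool:
    return stage == 0 and chunk == 0

def is_global_last(stage: int, chunk: int, num_stages: int, num_chunks: int) -> bool:
    return stage == num_stages - 1 and chunk == num_chunks - 1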


-@parameterize("num_micro_batches", [4, 8, 12])
-def examine_pp(num_micro_batches):
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
     """
     This test is to examine the correctness of interleaved 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
-    seed_all(1453)
-
-    NUM_MICRO_BATCHS = num_micro_batches
-    BATCH_SIZE = num_micro_batches
-    NUM_CHUNKS = 2
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

     # create model
+    seed_all(1453)
     torch_model = MlpModel().cuda()

     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
-    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+    schedule = InterleavedSchedule(
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+    )

     sharded_model = torch.nn.ModuleList()
-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx % world_size == rank:
             sub_model._forward = sub_model.forward
             sub_model.forward = MethodType(
-                partial(
-                    pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)
-                ),
+                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                 sub_model._forward,
             )
             sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"
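The idx % world_size == rank test assigns layers round-robin, so each rank holds num_model_chunk non-contiguous chunks. A quick illustration with hypothetical sizes:

# Round-robin placement: with 8 layers on 4 ranks and 2 chunks per rank,
# rank r owns global layers r and r + world_size.
NUM_LAYER, world_size = 8, 4
for rank in range(world_size):
    owned = [idx for idx in range(NUM_LAYER) if idx % world_size == rank]
    print(rank, owned)  # rank 0 -> [0, 4], rank 1 -> [1, 5], ...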

     # create optimizer
-    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
+    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

-    # create
-    seed_all(1453)
-    if local_rank == 0:
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    # create data
+    seed_all(115)
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()
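Replacing torch.mean(x) with (x * x).mean() is not cosmetic: the gradient of a plain mean is the constant 1/N regardless of the input, whereas the squared mean propagates input-dependent gradients, which is what the per-layer gradient checks below actually compare. The old call criterion(torch_output, _) also passed a leftover loop variable as the unused target. A quick standalone check of the new loss, with hypothetical values:

import torch

# d/dx of (x * x).mean() is 2x / N: it varies with x, so diverging pipeline
# activations cannot produce accidentally matching gradients.
x = torch.arange(4.0, requires_grad=True)
(x * x).mean().backward()
assert torch.equal(x.grad, 2 * x.detach() / x.numel())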

     pp_ret = schedule.forward_backward_step(
@@ -115,45 +108,60 @@ def examine_pp(num_micro_batches):
     )

     # check loss
-    if stage_manager.is_last_stage():
+    if stage_manager.is_last_stage(ignore_chunk=True):
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
-        else:
-            assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
-        else:
-            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
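The new forward-only section checks that after pp_optimizer.zero_grad(), a pass run under torch.no_grad() leaves gradients untouched: either still None (never populated) or still zero. The same invariant in miniature, outside the pipeline:

import torch

layer = torch.nn.Linear(4, 4)
layer(torch.rand(2, 4)).sum().backward()  # populate grads
layer.zero_grad(set_to_none=False)        # reset to zeros, keep the tensors
with torch.no_grad():
    layer(torch.rand(2, 4))               # no graph is built, no grads accumulate
assert torch.equal(layer.weight.grad, torch.zeros_like(layer.weight.grad))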


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("num_model_chunk", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 4)
+def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    assert NUM_LAYER % num_model_chunk == 0
+    spawn(
+        run_pp,
+        nprocs=NUM_LAYER // num_model_chunk,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+        num_model_chunk=num_model_chunk,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
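Note that the interleaved test now derives the world size instead of fixing it at 4: NUM_LAYER // num_model_chunk launches 4 processes for 2 chunks per rank and 2 processes for 4 chunks per rank, so every parametrization keeps exactly num_model_chunk layers on each rank:

# Spawned world sizes implied by the parametrization above.
NUM_LAYER = 8
for num_model_chunk in (2, 4):
    nprocs = NUM_LAYER // num_model_chunk
    assert nprocs * num_model_chunk == NUM_LAYER
    print(num_model_chunk, "chunks/rank ->", nprocs, "ranks")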
@@ -4,6 +4,7 @@ from types import MethodType

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai
@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+DIM = 8
+NUM_LAYER = 8
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        for layer in self.layers:
+            x = layer(x)
         return x


 def pp_linear_fwd(
-    forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
+    forward,
+    data: torch.Tensor = None,
+    input_obj: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
 ):
     if stage_mgr.is_first_stage():
         return {"input_obj": forward(data)}
@@ -38,34 +44,45 @@ def pp_linear_fwd(
         return {"input_obj": forward(input_obj)}


-def examine_pp():
+def examine_pp(num_microbatch: int, batch_size: int):
     """
     This test is to examine the correctness of 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    dist.get_rank()
     seed_all(1453)

-    NUM_MICRO_BATCHS = 4
-    BATCH_SIZE = 4
-
     # create models
     torch_model = MlpModel().cuda()

     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)

-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
-            sharded_model = sub_model.cuda()
+    rank = dist.get_rank()
+    sharded_model = torch.nn.ModuleList()
+    num_local_layer = NUM_LAYER // world_size
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx // num_local_layer == rank:
+            sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_local_layer
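Unlike the interleaved test, the 1F1B test shards contiguously: idx // num_local_layer == rank gives each rank one block of consecutive layers. With hypothetical sizes:

# Contiguous placement: 8 layers over 4 ranks -> 2 consecutive layers each.
NUM_LAYER, world_size = 8, 4
num_local_layer = NUM_LAYER // world_size
for rank in range(world_size):
    owned = [idx for idx in range(NUM_LAYER) if idx // num_local_layer == rank]
    print(rank, owned)  # rank 0 -> [0, 1], rank 1 -> [2, 3], ...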

-    sharded_model._forward = sharded_model.forward
-    sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
+    def custom_fwd(self, x):
+        for layer in self._modules.values():
+            x = layer(x)
+        return x
+
+    sharded_model._forward = MethodType(custom_fwd, sharded_model)
+    sharded_model.forward = MethodType(
+        partial(
+            pp_linear_fwd,
+            stage_mgr=stage_manager,
+        ),
+        sharded_model._forward,
+    )

     # create optimizer
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
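Because sharded_model is now a ModuleList, it has no usable forward of its own: custom_fwd chains the local layers, and the outer MethodType/partial wrapper decides whether to feed data or input_obj depending on the stage. The rebinding trick in isolation, with illustrative names:

from functools import partial
from types import MethodType

class Box:
    def forward(self, x):
        return x + 1

def wrapper(forward, x, scale=1):
    # `forward` is the saved bound method, injected as the first argument
    # by the MethodType binding below.
    return forward(x) * scale

box = Box()
box._forward = box.forward
box.forward = MethodType(partial(wrapper, scale=2), box._forward)
print(box.forward(1))  # (1 + 1) * 2 == 4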
@@ -73,19 +90,15 @@ def examine_pp():

     # create
     seed_all(1453)
-    if stage_manager.is_first_stage():
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    criterion = lambda x, *arg, **kwargs: (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(
         sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
     )
@@ -95,34 +108,66 @@ def examine_pp():
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage():
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


-def run_dist(rank, world_size, port):
+def run_dist(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    examine_pp(num_microbatch, batch_size)


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 6])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 2)
+def test_pp(num_microbatch: int, batch_size: int, world_size: int):
+    assert NUM_LAYER % world_size == 0
+    spawn(
+        run_dist,
+        world_size,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, world_size=4)
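As the __main__ guards show, each of these files can also be run as a plain script. Judging by the call sites above, colossalai.testing.spawn forks the requested number of local workers and invokes the target as fn(rank, world_size, port, **kwargs), which is why run_dist and run_pp take those three positional parameters before the test-specific ones. A minimal standalone driver in the same style, with hypothetical values:

# Run the 1F1B check by hand: 2 ranks, 4 microbatches, batch size 12.
if __name__ == "__main__":
    spawn(run_dist, 2, num_microbatch=4, batch_size=12)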