[booster] fix no_sync method (#3709)

* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
Hongxin Liu
2023-05-09 11:10:02 +08:00
committed by GitHub
parent 3bf09efe74
commit 6552cbf8e1
6 changed files with 85 additions and 5 deletions
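For orientation (the diff itself only touches tests): booster.no_sync(model) is expected to return a context manager that suppresses DDP's gradient all-reduce, so backward passes run inside it keep gradients local to each rank. Below is a minimal sketch of that behavior, assuming the boosted model wraps a plain torch DistributedDataParallel instance; it is not the actual ColossalAI implementation.

from contextlib import contextmanager
from typing import Iterator

from torch.nn.parallel import DistributedDataParallel as DDP


@contextmanager
def no_sync(model: DDP) -> Iterator[None]:
    # DDP.no_sync() disables gradient all-reduce for backward passes executed
    # inside the context; gradients are still accumulated locally on each rank.
    with model.no_sync():
        yield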


@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
    def supported_precisions(self) -> List[str]:
        pass

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        pass


def check_dataloader_sharding():
    plugin = DPPluginWrapper()
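The stub above only satisfies the plugin interface for this sharding test. The reason the interface needs no_sync at all is gradient accumulation, where the all-reduce should run only on the last micro-batch. An illustrative pattern is sketched below; the helper name and accum_steps are hypothetical, while booster, model, criterion, optimizer and train_dataloader follow the test further down.

from contextlib import nullcontext


def train_with_accumulation(booster, model, optimizer, criterion, train_dataloader, accum_steps=4):
    for i, batch in enumerate(train_dataloader):
        sync_now = (i + 1) % accum_steps == 0
        # skip gradient synchronization on all but the last micro-batch
        ctx = nullcontext() if sync_now else booster.no_sync(model)
        with ctx:
            loss = criterion(model(batch.cuda()))
            booster.backward(loss, optimizer)
        if sync_now:
            optimizer.step()
            optimizer.zero_grad()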


@@ -1,5 +1,8 @@
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
    torch.cuda.empty_cache()


class DummyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(1))

    def forward(self, x):
        return self.weight * x
def check_torch_ddp_no_sync():
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = DummyModel()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=1e-3)
    # create a dummy dataset with the integers 0..9
    dataset = torch.arange(0, 10)
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
                                                                     optimizer,
                                                                     criterion,
                                                                     dataloader=train_dataloader)
    def fwd_bwd():
        # `batch` is captured from the enclosing dataloader loop below
        output = model(batch.cuda())
        loss = criterion(output)
        booster.backward(loss, optimizer)

    def get_grad_set_over_all_ranks():
        for p in model.parameters():
            # grad shape is (1, )
            assert p.grad.shape == (1,)
            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
            dist.all_gather(grad_list, p.grad)
            # get grad set of all ranks
            grad_set = set([grad.item() for grad in grad_list])
            # as the model only has one parameter, we can return here
            return grad_set
    for i, batch in enumerate(train_dataloader):
        if i > 1:
            # only check the first two batches
            break
        # no_sync for the first batch, sync for the second batch
        ctx = booster.no_sync(model) if i == 0 else nullcontext()
        with ctx:
            fwd_bwd()
        grad_set = get_grad_set_over_all_ranks()
        # for the first batch, all ranks should have different grads
        # for the second batch, as grads are synchronized, all ranks should have the same grads
        target_num_different_grad = dist.get_world_size() if i == 0 else 1
        assert len(grad_set) == target_num_different_grad
def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_ddp_plugin()
    check_torch_ddp_no_sync()


@rerun_if_address_is_in_use()
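The decorated test entry point is cut off above. As a hypothetical launcher in the same shape (function name and port invented here), torch.multiprocessing.spawn calls run_dist(rank, world_size, port) once per process:

import torch.multiprocessing as mp


def test_torch_ddp_no_sync(world_size: int = 2):
    # port is fixed here for illustration; a real test would pick a free port
    mp.spawn(run_dist, args=(world_size, 29500), nprocs=world_size)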