[booster] fix no_sync method (#3709)

* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
Author: Hongxin Liu
Date: 2023-05-09 11:10:02 +08:00
Committed by: GitHub
Parent: 3bf09efe74
Commit: 6552cbf8e1
6 changed files with 85 additions and 5 deletions

@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()
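
For context, `DistributedDataParallel.no_sync()` is PyTorch's context manager for gradient accumulation: backward passes run inside it skip the cross-rank all-reduce, and gradients are synchronized again on the first backward after exiting. The new plugin method exposes that context manager through the booster API; note it unwraps `TorchDDPModel` (via `model.module`) because `no_sync` lives on the inner DDP instance, not on the user-facing module. A minimal usage sketch, assuming an already-launched distributed environment and the standard `Booster.boost` flow (the toy model, optimizer, and input names are illustrative, not from this commit):

import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

booster = Booster(plugin=TorchDDPPlugin())

# Hypothetical toy model/optimizer for illustration.
model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

device = next(model.parameters()).device
inputs = torch.rand(4, 8, device=device)

# Micro-batches 1..N-1: accumulate gradients locally, skipping DDP's all-reduce.
with booster.no_sync(model):
    model(inputs).sum().backward()

# Final micro-batch: outside no_sync, this backward triggers gradient synchronization.
model(inputs).sum().backward()
optimizer.step()
optimizer.zero_grad()

The assert in the patch also means `no_sync` fails loudly if called on a model that was not boosted by this plugin, rather than silently returning a context manager from the wrong wrapper.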