[booster] fix no_sync method (#3709)
* [booster] fix no_sync method
* [booster] add test for ddp no_sync
* [booster] fix merge
* [booster] update unit test
* [booster] update unit test
* [booster] update unit test
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError

@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
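In the two hunks above, GeminiPlugin and LowLevelZeroPlugin gain a no_sync stub that raises NotImplementedError: with sharded gradients there is no DDP-style bucket all-reduce to suppress, so the method is declared to satisfy the interface but left unsupported. A minimal sketch of how a caller or test could assert this behaviour; the PR's actual unit tests are not shown here, and plugin/model are assumed to be constructed elsewhere:

import pytest

def check_no_sync_unsupported(plugin, model):
    # Both plugins currently stub out no_sync, so calling it must raise.
    with pytest.raises(NotImplementedError):
        plugin.no_sync(model)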
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer

@@ -60,6 +60,13 @@ class Plugin(ABC):
         """
         pass
 
+    @abstractmethod
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        """
+        Context manager to disable gradient synchronization.
+        """
+        pass
+
     @abstractmethod
     def prepare_dataloader(self,
                            dataset: Dataset,
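This hunk makes no_sync part of the abstract Plugin contract: every plugin must return a context manager that suppresses gradient synchronization, or raise if it cannot. A sketch of the intended call pattern for gradient accumulation; only plugin.no_sync(model) comes from the diff, the helper and argument names are illustrative:

def accumulate_two_steps(plugin, model, optimizer, compute_loss, batch_a, batch_b):
    with plugin.no_sync(model):
        compute_loss(model, batch_a).backward()    # local accumulation, no all-reduce
    compute_loss(model, batch_b).backward()        # synchronization fires on this backward
    optimizer.step()
    optimizer.zero_grad()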
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()
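TorchDDPPlugin is the one plugin with a real implementation: it asserts the model was actually boosted, then forwards to PyTorch's DistributedDataParallel.no_sync() on the wrapped module. An end-to-end sketch under assumed setup (a launched distributed run, with model, optimizer, criterion, and dataloader defined beforehand; the boost call pattern here is illustrative):

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

batches = iter(dataloader)
inputs, labels = next(batches)
with plugin.no_sync(model):            # model is a TorchDDPModel after boosting
    booster.backward(criterion(model(inputs), labels), optimizer)

inputs, labels = next(batches)
booster.backward(criterion(model(inputs), labels), optimizer)  # gradients all-reduce here
optimizer.step()
optimizer.zero_grad()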