diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py
index af8d34904..9c9976a5d 100644
--- a/colossalai/amp/torch_amp/__init__.py
+++ b/colossalai/amp/torch_amp/__init__.py
@@ -3,12 +3,13 @@ from torch.optim import Optimizer
 from torch.nn.modules.loss import _Loss
 from colossalai.context import Config
 from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+from typing import Optional
 
 
 def convert_to_torch_amp(model: nn.Module,
                          optimizer: Optimizer,
-                         criterion: _Loss,
-                         amp_config: Config):
+                         criterion: Optional[_Loss] = None,
+                         amp_config: Optional[Config] = None):
     """A helper function to wrap training components with Torch AMP modules
 
     :param model: your model object
@@ -16,16 +17,18 @@ def convert_to_torch_amp(model: nn.Module,
     :param optimizer: your optimizer object
     :type optimizer: :class:`torch.optim.Optimzer`
     :param criterion: your loss function object
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
-    :type amp_config: :class:`colossalai.context.Config` or dict
-
+    :type amp_config: :class:`colossalai.context.Config` or dict, optional
     :return: (model, optimizer, criterion)
     :rtype: Tuple
     """
     model = TorchAMPModel(model)
+    if amp_config is None:
+        amp_config = dict()
     optimizer = TorchAMPOptimizer(optimizer, **amp_config)
-    criterion = TorchAMPLoss(criterion)
+    if criterion:
+        criterion = TorchAMPLoss(criterion)
     return model, optimizer, criterion
diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 6eea2649b..699268cec 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler
 
 
 class Engine:
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """
+
     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
@@ -47,7 +50,7 @@ class Engine:
         self._logger = get_dist_logger()
 
         # state
-        self.training = True # default
+        self.training = True  # default
 
         # build gradient handler
         if gradient_handlers:
@@ -55,7 +58,10 @@ class Engine:
         else:
             self._gradient_handlers = []
 
-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
 
     @property
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 010cee736..d87f9658b 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device,
                               is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook
 
 
 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)
 
 
-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
 
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
     :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
     if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)
 
-    engine = Engine(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        gradient_handlers=gradient_handlers,
-        clip_grad_norm=clip_grad_norm,
-        ophook_list=ophooks
-    )
+    engine = Engine(model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    gradient_handlers=gradient_handlers,
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks)
 
     return engine, train_dataloader, test_dataloader, lr_scheduler
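For context, a minimal usage sketch of the relaxed `convert_to_torch_amp` signature is shown below. It is not part of the patch: the toy `nn.Linear` model, SGD optimizer, and cross-entropy loss are placeholder assumptions, and real AMP training is expected to run on a CUDA device.

```python
# Hypothetical sketch of the new optional arguments; the model/optimizer/loss
# here are placeholders and not taken from the patch itself.
import torch
import torch.nn as nn

from colossalai.amp.torch_amp import convert_to_torch_amp

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# criterion and amp_config now default to None: amp_config falls back to an
# empty dict, and the criterion is only wrapped in TorchAMPLoss when provided.
model, optimizer, criterion = convert_to_torch_amp(model, optimizer)
assert criterion is None  # no loss function was passed, so none is returned

# Passing a criterion still wraps it, as before the patch.
model2 = nn.Linear(16, 4)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=1e-3)
model2, optimizer2, criterion2 = convert_to_torch_amp(
    model2, optimizer2, criterion=nn.CrossEntropyLoss())
```

The same pattern applies to `Engine` and `initialize`: `criterion`, `gradient_handlers`, and `ophooks`/`ophook_list` may now be omitted, with `None` replacing the previous mutable default for the ophook list.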