set criterion as optional in colossalai initialize (#336)

Frank Lee 2022-03-09 11:51:22 +08:00
parent 3213554cc2
commit 6a3188167c
3 changed files with 46 additions and 39 deletions

Changed file 1 of 3

@@ -3,12 +3,13 @@ from torch.optim import Optimizer
 from torch.nn.modules.loss import _Loss
 from colossalai.context import Config
 from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+from typing import Optional


 def convert_to_torch_amp(model: nn.Module,
                          optimizer: Optimizer,
-                         criterion: _Loss,
-                         amp_config: Config):
+                         criterion: Optional[_Loss] = None,
+                         amp_config: Optional[Config] = None):
     """A helper function to wrap training components with Torch AMP modules

     :param model: your model object
@@ -16,16 +17,18 @@ def convert_to_torch_amp(model: nn.Module,
     :param optimizer: your optimizer object
     :type optimizer: :class:`torch.optim.Optimzer`
     :param criterion: your loss function object
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
-    :type amp_config: :class:`colossalai.context.Config` or dict
+    :type amp_config: :class:`colossalai.context.Config` or dict, optional
     :return: (model, optimizer, criterion)
     :rtype: Tuple
     """
     model = TorchAMPModel(model)
+    if amp_config is None:
+        amp_config = dict()
     optimizer = TorchAMPOptimizer(optimizer, **amp_config)
-    criterion = TorchAMPLoss(criterion)
+    if criterion:
+        criterion = TorchAMPLoss(criterion)
     return model, optimizer, criterion
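
For context, a minimal sketch of how the relaxed helper could be called after this change. The import path colossalai.amp.torch_amp is inferred from the relative imports in the hunk above, and the toy model and optimizer are hypothetical placeholders, not part of the commit.

# Sketch only: criterion and amp_config may now be omitted.
import torch.nn as nn
import torch.optim as optim
from colossalai.amp.torch_amp import convert_to_torch_amp   # assumed import path

model = nn.Linear(8, 2)                      # hypothetical model
optimizer = optim.Adam(model.parameters())   # hypothetical optimizer

# amp_config falls back to an empty dict; a missing criterion is passed through as None
model, optimizer, criterion = convert_to_torch_amp(model, optimizer)
assert criterion is None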

Changed file 2 of 3

@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler


 class Engine:
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """

     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
@@ -47,7 +50,7 @@ class Engine:
         self._logger = get_dist_logger()

         # state
         self.training = True  # default

         # build gradient handler
         if gradient_handlers:
@@ -55,7 +58,10 @@ class Engine:
         else:
             self._gradient_handlers = []

-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)

     @property
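
A small sketch of what the relaxed constructor permits. In practice the engine is normally produced by colossalai.initialize rather than built by hand, and the import path and toy model below are assumptions rather than part of this commit.

# Sketch only: criterion and ophook_list can now be left out.
import torch.nn as nn
import torch.optim as optim
from colossalai.engine import Engine   # assumed import path

model = nn.Linear(8, 2)                             # hypothetical model
optimizer = optim.SGD(model.parameters(), lr=0.1)   # hypothetical optimizer

# criterion defaults to None; ophook_list defaults to an empty list internally
engine = Engine(model=model, optimizer=optimizer)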

Changed file 3 of 3

@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook


 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)


-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
     :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
     if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)

-    engine = Engine(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        gradient_handlers=gradient_handlers,
-        clip_grad_norm=clip_grad_norm,
-        ophook_list=ophooks
-    )
+    engine = Engine(model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    gradient_handlers=gradient_handlers,
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks)

     return engine, train_dataloader, test_dataloader, lr_scheduler
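
Taken together, these changes let colossalai.initialize be called without a criterion, for example when the loss is computed inside the model's forward pass. A hedged sketch, assuming the distributed context has already been set up with one of the launch helpers and using a toy model and optimizer that are not part of the commit:

import torch
import torch.nn as nn
import colossalai

# assumes colossalai.launch_from_torch(config=...) (or another launch helper)
# has already been called so that gpc.config is populated
model = nn.Linear(16, 4)                                  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # hypothetical optimizer

# criterion, dataloaders, lr_scheduler and ophooks are all optional now
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer)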