set criterion as optional in colossalai initialize (#336)

Frank Lee
2022-03-09 11:51:22 +08:00
parent 3213554cc2
commit 6a3188167c
3 changed files with 46 additions and 39 deletions


@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler
 
 
 class Engine:
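
The two added imports support the signature change below: typing.Optional marks a value that may be None, and BaseGradientHandler lets the gradient_handlers parameter be typed precisely. As a quick aside (illustrative only, not part of the commit), Optional[X] is exactly Union[X, None] and does not by itself supply a default value:

    from typing import List, Optional, Union

    # Optional[X] is exactly Union[X, None].
    assert Optional[int] == Union[int, None]

    # Annotating a parameter as Optional does not give it a default;
    # the "= None" must still be written explicitly, as this commit does.
    def count(xs: Optional[List[int]] = None) -> int:
        return 0 if xs is None else len(xs)

    print(count())        # 0
    print(count([1, 2]))  # 2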
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """
+
     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
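
With the new signature, criterion, gradient_handlers, and ophook_list may all be omitted when an Engine is built. A minimal sketch of what this permits, assuming Engine can be imported from colossalai.engine and constructed directly (normally colossalai.initialize builds it for you); the toy model and optimizer here are illustrative:

    import torch
    import torch.nn as nn
    from colossalai.engine import Engine  # assumed import path

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # criterion now defaults to None, e.g. for models that compute
    # their loss inside forward(); gradient_handlers and ophook_list
    # may likewise be left unset.
    engine = Engine(model=model, optimizer=optimizer)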
@@ -47,7 +50,7 @@
         self._logger = get_dist_logger()
 
         # state
-        self.training = True # default
+        self.training = True  # default
 
         # build gradient handler
         if gradient_handlers:
@@ -55,7 +58,10 @@
         else:
             self._gradient_handlers = []
 
-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
 
     @property
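
Changing the ophook_list default from [] to None sidesteps Python's shared-mutable-default pitfall: a default list is created once, at function definition time, and reused by every call. A self-contained demonstration (function names invented for the demo):

    def append_shared(item, bucket=[]):
        # The [] default is evaluated once, so calls without an explicit
        # bucket all mutate the same list object.
        bucket.append(item)
        return bucket

    def append_fresh(item, bucket=None):
        # The None-sentinel pattern this commit adopts: new list per call.
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket

    print(append_shared(1))  # [1]
    print(append_shared(2))  # [1, 2]  <- state leaked between calls
    print(append_fresh(1))   # [1]
    print(append_fresh(2))   # [2]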