mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
set criterion as optional in colossalai initialize (#336)
This commit is contained in:
@@ -9,6 +9,8 @@ from torch.optim import Optimizer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from torch import Tensor
|
||||
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
|
||||
from typing import Optional
|
||||
from colossalai.engine.gradient_handler import BaseGradientHandler
|
||||
|
||||
|
||||
class Engine:
|
||||
@@ -21,9 +23,9 @@ class Engine:
|
||||
:param optimizer: Optimizer for updating the parameters
|
||||
:type optimizer: ``torch.optim.Optimizer``
|
||||
:param criterion: Loss function for calculating loss
|
||||
:type criterion: ``torch.nn.modules.loss._Loss``
|
||||
:type criterion: ``torch.nn.modules.loss._Loss``, optional
|
||||
:param gradient_handlers: A list of gradient handler used in backward
|
||||
:type gradient_handlers: list
|
||||
:type gradient_handlers: a list of ``BaseGradientHandler``, optional
|
||||
:param clip_grad_norm: The norm of gradient clipping
|
||||
:type clip_grad_norm: float, optional
|
||||
:param ophook_list: List of ophook
|
||||
@@ -31,13 +33,14 @@ class Engine:
|
||||
:param verbose: whether to display log info
|
||||
:type verbose: bool
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: _Loss,
|
||||
gradient_handlers: List = None,
|
||||
criterion: Optional[_Loss] = None,
|
||||
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
||||
clip_grad_norm: float = 0.0,
|
||||
ophook_list: List[BaseOpHook] = [],
|
||||
ophook_list: Optional[List[BaseOpHook]] = None,
|
||||
verbose: bool = True):
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
@@ -47,7 +50,7 @@ class Engine:
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
# state
|
||||
self.training = True # default
|
||||
self.training = True # default
|
||||
|
||||
# build gradient handler
|
||||
if gradient_handlers:
|
||||
@@ -55,7 +58,10 @@ class Engine:
|
||||
else:
|
||||
self._gradient_handlers = []
|
||||
|
||||
self._ophook_list = ophook_list
|
||||
if ophook_list is None:
|
||||
self._ophook_list = []
|
||||
else:
|
||||
self._ophook_list = ophook_list
|
||||
register_ophooks_recursively(self._model, self._ophook_list)
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user