Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
set criterion as optional in colossalai initialize (#336)
@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook
 
 
 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)
 
 
-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
 
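Beyond making criterion optional, the new signature drops the Union[..., List[...]] variants and replaces the mutable default ophooks: List[BaseOpHook] = [] with Optional[List[BaseOpHook]] = None. The latter is the standard guard against Python's shared-mutable-default pitfall, where a single default list instance is created once and persists across calls. A minimal standalone sketch of the pitfall and the fix; the function names here are illustrative, not part of ColossalAI:

def append_hook(hook, hooks=[]):
    # BUG: the default list is created once at definition time and
    # shared by every call that omits the argument
    hooks.append(hook)
    return hooks

append_hook('a')            # ['a']
append_hook('b')            # ['a', 'b']  <- state leaked from the first call

def append_hook_fixed(hook, hooks=None):
    # the Optional[...] = None idiom: build a fresh list per call
    if hooks is None:
        hooks = []
    hooks.append(hook)
    return hooks

append_hook_fixed('a')      # ['a']
append_hook_fixed('b')      # ['b']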
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
     :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)
 
-    engine = Engine(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        gradient_handlers=gradient_handlers,
-        clip_grad_norm=clip_grad_norm,
-        ophook_list=ophooks
-    )
+    engine = Engine(model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    gradient_handlers=gradient_handlers,
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks)
 
     return engine, train_dataloader, test_dataloader, lr_scheduler
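Note that grad_accum_size is read from gpc.config.get('gradient_accumulation', None), so gradient accumulation is toggled purely through the config file. A minimal sketch of such a config, assuming the ColossalAI convention of plain module-level assignments; only the gradient_accumulation key is actually read by the code in this hunk:

# config.py -- hypothetical example; only gradient_accumulation is
# referenced by the initialize() logic shown above
gradient_accumulation = 4    # accumulate gradients over 4 micro-batches per optimizer step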
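With criterion now defaulting to None, a script whose loss is computed elsewhere (for example inside the model's forward pass) can call initialize without one. A minimal usage sketch, assuming a torch-launched run; the config path and toy model are placeholders, not part of this commit:

import colossalai
import torch
import torch.nn as nn

colossalai.launch_from_torch(config='./config.py')   # placeholder config path

model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# criterion may now be omitted; it defaults to None and is handed to the
# Engine unchanged
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer)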