set criterion as optional in colossalai initialize (#336)

This commit is contained in:
Frank Lee
2022-03-09 11:51:22 +08:00
parent 3213554cc2
commit 6a3188167c
3 changed files with 46 additions and 39 deletions

View File

@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.ophooks import BaseOpHook
def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose=verbose)
def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer: Union[Optimizer, List[Optimizer]],
criterion: Union[_Loss, List[_Loss]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None,
ophooks: List[BaseOpHook] = [],
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
def initialize(model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
ophooks: Optional[List[BaseOpHook]] = None,
verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
:param optimizer: Your optimizer instance
:type optimizer: :class:`torch.optim.optimizer.Optimizer`
:param criterion: Your criterion instance
:type criterion: :class:`torch.nn.modules.loss._Loss`
:type criterion: :class:`torch.nn.modules.loss._Loss`, optional
:param train_dataloader: Dataloader for training
:type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param test_dataloader: Dataloader for testing
:type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param lr_scheduler: Your lr scheduler instance
:param lr_scheduler: Your lr scheduler instance, optional
:type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# gradient accumulation
grad_accum_size = gpc.config.get('gradient_accumulation', None)
if grad_accum_size is not None:
optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
optimizer=optimizer,
dataloader=train_dataloader,
accumulate_size=grad_accum_size,
gradient_handlers=gradient_handlers,
lr_scheduler=lr_scheduler)
optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
accumulate_size=grad_accum_size,
gradient_handlers=gradient_handlers,
lr_scheduler=lr_scheduler)
engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
ophook_list=ophooks
)
engine = Engine(model=model,
optimizer=optimizer,
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
ophook_list=ophooks)
return engine, train_dataloader, test_dataloader, lr_scheduler