bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
This commit is contained in:
Jie Zhu 2022-03-02 14:25:52 +08:00 committed by GitHub
parent 3280869358
commit 3b64dcc439
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
is_using_ddp, is_using_pp, is_using_sequence, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param) sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
def get_default_parser(): def get_default_parser():
@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None, lr_scheduler: _LRScheduler = None,
ophooks: List[BaseOpHook] = [],
verbose: bool = True verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is """Core function to wrap the essential training components with our functionality based on the config which is
@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
gradient_handlers=gradient_handlers, gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm clip_grad_norm=clip_grad_norm,
ophook_list=ophooks
) )
return engine, train_dataloader, test_dataloader, lr_scheduler return engine, train_dataloader, test_dataloader, lr_scheduler