bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
This commit is contained in:
Jie Zhu
2022-03-02 14:25:52 +08:00
committed by Frank Lee
parent 5a560a060a
commit f867365aba

View File

@@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
is_using_ddp, is_using_pp, is_using_sequence, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param) sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
def get_default_parser(): def get_default_parser():
@@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None, lr_scheduler: _LRScheduler = None,
ophooks: List[BaseOpHook] = [],
verbose: bool = True verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is """Core function to wrap the essential training components with our functionality based on the config which is
@@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
gradient_handlers=gradient_handlers, gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm clip_grad_norm=clip_grad_norm,
ophook_list=ophooks
) )
return engine, train_dataloader, test_dataloader, lr_scheduler return engine, train_dataloader, test_dataloader, lr_scheduler