[misc] add verbose arg for zero and op builder (#3552)

* [misc] add print verbose

* [gemini] add print verbose

* [zero] add print verbose for low level

* [misc] add print verbose for op builder
Author: Hongxin Liu
Date:   2023-04-17 11:25:35 +08:00
Committed by: GitHub
Parent: 4341f5e8e6
Commit: 173dad0562
8 changed files with 55 additions and 28 deletions


@@ -65,9 +65,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):

 class GeminiModel(ModelWrapper):

-    def __init__(self, module: nn.Module, gemini_config: dict) -> None:
+    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
         super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
+        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)

     def unwrap(self):
         # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
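For context, a minimal, hypothetical sketch of calling `zero_model_wrapper` directly with the new flag; the import path and the `gemini_config` keys are assumptions about the surrounding ColossalAI version, not part of this diff.

```python
# Illustrative sketch only (assumes colossalai.zero.zero_model_wrapper and the
# config keys below exist in this version); run inside an initialized
# distributed environment.
import torch
import torch.nn as nn
from colossalai.zero import zero_model_wrapper

model = nn.Linear(1024, 1024).cuda()
gemini_config = dict(device=torch.cuda.current_device(), placement_policy='cuda')

# With verbose=True, the wrapper prints debug info such as the chunk search result.
gemini_model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config, verbose=True)
```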
@@ -76,8 +76,17 @@ class GeminiModel(ModelWrapper):

 class GeminiOptimizer(OptimizerWrapper):

-    def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
+    def __init__(self,
+                 module: GeminiDDP,
+                 optimizer: Optimizer,
+                 zero_optim_config: dict,
+                 optim_kwargs: dict,
+                 verbose: bool = False) -> None:
+        optimizer = zero_optim_wrapper(module,
+                                       optimizer,
+                                       optim_config=zero_optim_config,
+                                       **optim_kwargs,
+                                       verbose=verbose)
         super().__init__(optimizer)

     def backward(self, loss: Tensor, *args, **kwargs):
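Continuing the sketch above, the optimizer side mirrors what `GeminiOptimizer` does in this hunk; the `optim_config` key and the scaler kwarg below are illustrative assumptions rather than values taken from the diff.

```python
# Illustrative sketch only: wrap the optimizer the same way GeminiOptimizer does.
from torch.optim import Adam
from colossalai.zero import zero_optim_wrapper

optimizer = Adam(gemini_model.parameters(), lr=1e-3)

# optim_config plays the role of zero_optim_config above; the key here is an
# assumed example. verbose=True asks the wrapper to print its wrapping details.
gemini_optimizer = zero_optim_wrapper(gemini_model,
                                      optimizer,
                                      optim_config=dict(gpu_margin_mem_ratio=0.0),
                                      initial_scale=2**16,
                                      verbose=True)
```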
@@ -138,6 +147,7 @@ class GeminiPlugin(Plugin):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
     """

     def __init__(
@@ -161,6 +171,7 @@ class GeminiPlugin(Plugin):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        verbose: bool = False,
     ) -> None:
         assert dist.is_initialized(
@@ -188,6 +199,7 @@ class GeminiPlugin(Plugin):
             max_scale=max_scale,
             max_norm=max_norm,
             norm_type=norm_type)
+        self.verbose = verbose

     def support_no_sync(self) -> bool:
         return False
@@ -275,10 +287,11 @@ class GeminiPlugin(Plugin):
         # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

         # wrap the model with Gemini
-        model = GeminiModel(model, self.gemini_config)
+        model = GeminiModel(model, self.gemini_config, self.verbose)

         if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+                                        self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler
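End to end, the new plugin argument can be exercised roughly as follows; the `Booster` workflow, import paths, and launch call are assumptions about the version this commit targets and are not shown in the diff.

```python
# Illustrative usage sketch (assumes colossalai.booster.Booster, GeminiPlugin's
# import path, and launch_from_torch match this version). Launch with torchrun.
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# verbose=True flows from the plugin into GeminiModel and GeminiOptimizer,
# printing debug info such as the chunk search result during boost().
plugin = GeminiPlugin(verbose=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```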