[zero] refactor ShardedOptimV2 init method (#416)

Jiarui Fang
2022-03-15 10:45:55 +08:00
committed by GitHub
parent e79ea44247
commit 23ba3fc450
3 changed files with 32 additions and 16 deletions


@@ -29,7 +29,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def __init__(self,
sharded_model: ShardedModelV2,
optimizer_class: Type[Optimizer],
shard_strategy: BaseShardStrategy,
cpu_offload: bool = False,
initial_scale: float = 2**32,
min_scale: float = 1,
@@ -42,20 +41,43 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
mp_process_group: Optional[ProcessGroup] = None,
**defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by the sharded model to shard the fp32 param tensors.
:type sharded_model: ShardedModelV2
:param optimizer_class: A type of Optimizer
:param optimizer_class: A class type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: whether to offload the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param initial_scale: initial scale used by DynamicGradScaler
:type initial_scale: float
:param min_scale: min scale used by DynamicGradScaler
:type min_scale: float
:param growth_factor: growth_factor used by DynamicGradScaler
:type growth_factor: float
:param backoff_factor: backoff_factor used by DynamicGradScaler
:type backoff_factor: float
:param growth_interval: growth_interval used by DynamicGradScaler
:type growth_interval: float
:param hysteresis: hysteresis used by DynamicGradScaler
:type hysteresis: float
:param max_scale: max_scale used by DynamicGradScaler
:type max_scale: float
:param dp_process_group: data parallel process group
:type dp_process_group: Optional[ProcessGroup]
:param mp_process_group: model parallel process group
:type mp_process_group: Optional[ProcessGroup]
:param defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
@@ -67,7 +89,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
super().__init__(self.optimizer)
self.shard_strategy = shard_strategy
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
if cpu_offload and not sharded_model.cpu_offload:
raise RuntimeError(
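
Taken together, these hunks change the construction API: callers no longer pass a shard_strategy to ShardedOptimizerV2, because the optimizer now reuses sharded_model.shard_strategy, so the strategy used for the optimizer's fp32 param shards can never disagree with the one the model was sharded with. Below is a minimal call-site sketch, not part of this commit, assuming an already-built ShardedModelV2 instance named sharded_model, torch.optim.Adam as the wrapped optimizer class, and an import path matching the file location at this commit.

import torch.optim as optim
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed import path

# sharded_model is assumed to be an existing ShardedModelV2 (see the first hunk's signature).

# Before this commit, the call site had to repeat the strategy the model already carried:
#   optimizer = ShardedOptimizerV2(sharded_model, optim.Adam, shard_strategy,
#                                  cpu_offload=False, lr=1e-3)

# After this commit, the strategy is read from sharded_model.shard_strategy, and trailing
# keyword arguments (here lr) are forwarded to the locally constructed Adam instance:
optimizer = ShardedOptimizerV2(sharded_model, optim.Adam, cpu_offload=False, lr=1e-3)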