Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 12:12:46 +00:00)
[zero] refactory ShardedOptimV2 init method (#416)
This commit is contained in:
  parent e79ea44247
  commit 23ba3fc450
@@ -29,7 +29,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  sharded_model: ShardedModelV2,
                  optimizer_class: Type[Optimizer],
-                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -42,20 +41,43 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  mp_process_group: Optional[ProcessGroup] = None,
                  **defaults: Any) -> None:
         """
-        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
+            shard strategy provided by sharded model to shard param fp32 tensors.
         :type sharded_model: sharded_model
 
-        :param optimizer_class: A type of Optimizer
+        :param optimizer_class: A class type of Optimizer
         :type optimizer_class: Type[Optimizer]
 
-        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
-        :type shard_strategy: BaseShardStrategy
         :param cpu_offload: is offloading the optimizer states to CPU.
         :type cpu_offload: bool
 
-        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
-        :type shard_strategy: BaseShardStrategy
+        :param initial_scale: initial scale used by DynamicGradScaler
+        :type initial_scale: float
 
+        :param min_scale: min scale used by DynamicGradScaler
+        :type min_scale: float
+
+        :param growth_factor: growth_factor used by DynamicGradScaler
+        :type growth_factor: float
+
+        :param backoff_factor: backoff_factor used by DynamicGradScaler
+        :type backoff_factor: float
+
+        :param growth_interval: growth_interval used by DynamicGradScaler
+        :type growth_interval: float
+
+        :param hysteresis: hysteresis used by DynamicGradScaler
+        :type hysteresis: float
+
+        :param max_scale: max_scale used by DynamicGradScaler
+        :type max_scale: float
+
+        :param dp_process_group: data paralle process group
+        :type dp_process_group: Optional[ProcessGroup]
+
+        :param mp_process_group: model paralle process group
+        :type mp_process_group: Optional[ProcessGroup]
+
         :**defaults: any trailing arguments, which are forwarded to the local optimizer.
         :type defaults: dict()
         """
@@ -67,7 +89,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
 
         super().__init__(self.optimizer)
-        self.shard_strategy = shard_strategy
+        self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
             raise RuntimeError(
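With this refactor, ShardedOptimizerV2 no longer takes a shard_strategy argument; it reuses the strategy already attached to the sharded model (self.shard_strategy = sharded_model.shard_strategy). A minimal construction sketch under the new signature follows; the import paths, the TensorShardStrategy choice, and the toy torch.nn.Linear module are assumptions for illustration and are not part of this commit.

import torch
from torch.optim import Adam

# Assumed import paths; adjust to the ColossalAI version in use.
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

# Hypothetical module used only for illustration.
model = torch.nn.Linear(16, 16).cuda()

# The shard strategy is given to the sharded model once ...
zero_model = ShardedModelV2(model, TensorShardStrategy())

# ... and the optimizer picks it up via zero_model.shard_strategy, so no
# shard_strategy argument is passed here. DynamicGradScaler settings
# (initial_scale, min_scale, ...) go in as keyword arguments.
sharded_optim = ShardedOptimizerV2(zero_model,
                                   Adam,
                                   cpu_offload=False,
                                   initial_scale=2**5,
                                   lr=1e-3)

This mirrors the updated call sites in the tests below, where the shard_strategy positional argument is simply dropped.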
@@ -52,7 +52,6 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     optim = optimizer_class(model.parameters(), lr=lr)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        optimizer_class,
-                                       shard_strategy,
                                        cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        lr=lr)
@@ -59,12 +59,7 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(zero_model,
-                                       CPUAdam,
-                                       shard_strategy,
-                                       initial_scale=2**5,
-                                       cpu_offload=True,
-                                       lr=1e-3)
+    sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break
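The unchanged guard in __init__ (raise RuntimeError when cpu_offload is requested but the sharded model does not offload) still applies after the refactor. The sketch below shows a CPU-offload setup that satisfies it; the offload_config argument name and the CPUAdam import path are assumptions, not taken from this commit.

import torch
from colossalai.nn.optimizer import CPUAdam           # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

model = torch.nn.Linear(16, 16)  # hypothetical module

# Assumption: the sharded model enables CPU offload through an offload_config
# dict, which sets the zero_model.cpu_offload flag checked by the optimizer.
zero_model = ShardedModelV2(model, TensorShardStrategy(),
                            offload_config={'device': 'cpu'})

# cpu_offload=True is only legal because zero_model.cpu_offload is True;
# otherwise ShardedOptimizerV2.__init__ raises RuntimeError.
sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)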