[feature] support no master weights option for low level zero plugin (#4816)
* [feature] support no master weights for low level zero plugin
* [feature] support no master weights for low level zero plugin, remove data copy when no master weights
* remove data copy and typecasting when no master weights
* not load weights to cpu when using no master weights
* fix grad: use fp16 grad when no master weights
* only do not update working param when no master weights
* fix: only do not update working param when no master weights
* fix: passing params in dict format in hybrid plugin
* fix: remove extra params (tp_process_group) in hybrid_parallel_plugin
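For context, a minimal sketch of how the new option might be enabled from user code, assuming the standard Booster workflow and the launcher API of this release; everything except the master_weights=False argument added by this commit is placeholder setup:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})  # distributed init (launcher API of this release)

model = torch.nn.Linear(1024, 1024).cuda()                  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # placeholder optimizer

# master_weights=False keeps parameters in half precision end to end,
# skipping the fp32 master copy (and the data copies that maintain it).
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)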
@@ -464,23 +464,23 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
-            optimizer,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            clip_grad_norm,
-            verbose,
-            reduce_bucket_size,
-            communication_dtype,
-            overlap_communication,
-            partition_grad,
-            cpu_offload,
-            dp_process_group,
-            forced_dtype,
+            optimizer=optimizer,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            clip_grad_norm=clip_grad_norm,
+            verbose=verbose,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            partition_grad=partition_grad,
+            cpu_offload=cpu_offload,
+            dp_process_group=dp_process_group,
+            forced_dtype=forced_dtype,
         )

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
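An aside on the hunk above, which switches the super().__init__ call from positional to keyword arguments (the "passing params in dict format" fix in the commit message): a toy sketch, not ColossalAI code, of why keyword forwarding is safer once a parent signature gains or reorders parameters:

class Parent:
    # Suppose a refactor inserts master_weights mid-signature.
    def __init__(self, optimizer, initial_scale=2**16, master_weights=True, min_scale=1):
        self.optimizer = optimizer
        self.initial_scale = initial_scale
        self.master_weights = master_weights
        self.min_scale = min_scale

class ChildPositional(Parent):
    def __init__(self, optimizer, initial_scale, min_scale):
        # Positional forwarding now silently binds min_scale to master_weights.
        super().__init__(optimizer, initial_scale, min_scale)

class ChildKeyword(Parent):
    def __init__(self, optimizer, initial_scale, min_scale):
        # Keyword forwarding keeps every argument bound to the intended name.
        super().__init__(optimizer=optimizer, initial_scale=initial_scale, min_scale=min_scale)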
@@ -262,6 +262,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         cpu_offload: bool = False,
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -272,18 +273,19 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.precision = precision
         self.zero_optim_kwargs = dict(
             initial_scale=initial_scale,
-            min_scale=min_scale,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
+            min_scale=min_scale,
             max_scale=max_scale,
             clip_grad_norm=max_norm,
             reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
-            cpu_offload=cpu_offload,
             partition_grad=(stage == 2),
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
         )
         self.verbose = verbose
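For intuition about what the master_weights flag controls, a conceptual sketch only (the plugin's real update path lives in the low level zero optimizer and is more involved): with master weights, the optimizer steps an fp32 copy and casts the result back into the fp16 working parameter; without it, the fp16 parameter is updated in place with its fp16 gradient, avoiding the extra copy and typecast per step.

import torch

def step_with_master(working: torch.Tensor, master: torch.Tensor, lr: float) -> None:
    # fp16 grad is upcast, applied to the fp32 master copy,
    # and the result is cast back into the working param.
    master -= lr * working.grad.float()
    working.data.copy_(master.half())

def step_without_master(working: torch.Tensor, lr: float) -> None:
    # No fp32 copy: update the fp16 param directly with its fp16 grad.
    working.data -= lr * working.grad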