hpcaitech/ColossalAI (https://github.com/hpcaitech/ColossalAI.git)

[zero] trivial zero optimizer refactoring (#2869)

* Fix minor grad store interface
* Apply lint

Parent: dbc01b9c04
Commit: 7b13f7db18
@@ -6,6 +6,7 @@ from .base_store import BaseStore


 class GradientStore(BaseStore):

     def __init__(self, *args):
         super().__init__(*args)
         # bookkeeping data structures
@@ -56,9 +57,7 @@ class GradientStore(BaseStore):
         else:
             self._averaged_gradients[group_id] = [tensor]

-    def add_average_gradient_by_group(
-        self, group_id: int, tensor_idx: int, tensor: Tensor
-    ) -> None:
+    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
         """
         Add an average gradient to the list of averaged gradients of a parameter group
@@ -81,3 +80,9 @@ class GradientStore(BaseStore):
         """

         self._averaged_gradients[group_id] = []
+
+    def reset_all_average_gradients(self) -> None:
+        """
+        Reset the bookkeeping data structure for averaged gradients to an empty list
+        """
+        self._averaged_gradients = dict()
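For orientation, here is a minimal, runnable sketch of the bookkeeping interface these GradientStore hunks converge on. The method names, the docstrings, and the bodies of the two reset methods come from the diff itself; the bodies of the getter, append, and add methods are assumptions filled in for illustration and are not the actual ColossalAI implementation.

from typing import Dict, List

from torch import Tensor


class GradientStoreSketch:
    """Illustrative stand-in for the real GradientStore (which subclasses BaseStore)."""

    def __init__(self) -> None:
        # group_id -> list of averaged gradient tensors, one slot per parameter
        self._averaged_gradients: Dict[int, List[Tensor]] = dict()

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        # Return (creating on first access) the gradient list of a parameter group.
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        # Grow the group's list when a parameter's gradient is seen for the first time.
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        # Accumulate into an existing slot on later accumulation steps (assumed semantics).
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        # Drop one group's bookkeeping once its gradients have been consumed.
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        # Reset the bookkeeping data structure for averaged gradients (as added above).
        self._averaged_gradients = dict()

Keeping these mutations behind methods is what lets the optimizer hunks below stop reaching into the store's private _averaged_gradients dict.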
@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
-        for group_id, param_group in self._fp16_param_groups.items():
+        for _, param_group in self._fp16_param_groups.items():
             for param in param_group:
                 if set_to_none:
                     param.grad = None
@@ -438,7 +438,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         # update loss scale if overflow occurs
         if found_inf:
-            self._grad_store._averaged_gradients = dict()
+            self._grad_store.reset_all_average_gradients()
             self.zero_grad()
             return

@@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         for group_id in range(self.num_param_groups):
             # compute norm
-            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
+            norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
                                       params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                               rank=self._local_rank),
                                       dp_group=self._dp_torch_group,
@@ -469,8 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-            self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []
+            self._grad_store.reset_average_gradients_by_group(group_id)

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
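The two optimizer hunks above switch the step path to read a group's averaged gradients through the store and to clear that group through the store as well, instead of assigning to the private dict (twice, in the old code). A rough sketch of the resulting consume-then-reset pattern, reusing the GradientStoreSketch above; the torch.cat flattening and the fp32 conversion stand in for the optimizer's real flattening and master-weight machinery:

import torch


def consume_group_gradients(store: GradientStoreSketch, group_id: int) -> torch.Tensor:
    # Read the group's averaged gradients through the store interface ...
    avg_grads = store.get_averaged_gradients_by_group(group_id)
    # ... flatten them into a single fp32 buffer for the master copy ...
    flat_fp32_avg_grads = torch.cat([g.reshape(-1) for g in avg_grads]).float()
    # ... then let the store clear its own bookkeeping for this group.
    store.reset_average_gradients_by_group(group_id)
    return flat_fp32_avg_grads


store = GradientStoreSketch()
store.append_average_gradient_by_group(0, torch.ones(4))
store.append_average_gradient_by_group(0, torch.ones(2))
print(consume_group_gradients(store, 0).shape)    # torch.Size([6])
print(store.get_averaged_gradients_by_group(0))   # []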
@@ -546,28 +545,22 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def _sync_grad(self):
         # update param already reduced flag
         reduction_states = self._param_store.get_param_reduction_states()
-        for tensor, state in reduction_states.items():
+        for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False

         # accumulate gradient
         for group_id in range(self.num_param_groups):
             param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)

-            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
-                group_id
-            )
+            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)

             param_idx = 0
             for param in param_group:
                 if param.grad is not None:
                     if len(avg_gradients_group) == param_idx:
-                        self._grad_store.append_average_gradient_by_group(
-                            group_id, param.grad
-                        )
+                        self._grad_store.append_average_gradient_by_group(group_id, param.grad)
                     else:
-                        self._grad_store.add_average_gradient_by_group(
-                            group_id, param_idx, param.grad
-                        )
+                        self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
                     param_idx += 1

         # the gradients needed are stored in the avg_gradients buffer
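The _sync_grad hunk is where the append/add split in the store interface matters: on the first accumulation step a group's list is still empty, so each gradient is appended; on later steps the slot already exists and the new gradient is accumulated into it. A toy walk-through using the GradientStoreSketch above (the real code iterates the rank's fp16 params and uses param.grad):

import torch

store = GradientStoreSketch()
group_id = 0


def sync_micro_batch_grads(grads):
    # Mirrors the len(avg_gradients_group) == param_idx check in _sync_grad.
    avg_gradients_group = store.get_averaged_gradients_by_group(group_id)
    for param_idx, grad in enumerate(grads):
        if len(avg_gradients_group) == param_idx:
            store.append_average_gradient_by_group(group_id, grad.clone())
        else:
            store.add_average_gradient_by_group(group_id, param_idx, grad)


sync_micro_batch_grads([torch.full((3,), 1.0), torch.full((3,), 2.0)])    # first step: appends
sync_micro_batch_grads([torch.full((3,), 10.0), torch.full((3,), 20.0)])  # second step: accumulates
print(store.get_averaged_gradients_by_group(group_id)[1])  # tensor([22., 22., 22.])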
|