Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin
* polish code
* add unittests
* Move tp to a higher-level optimizer interface.
* bug fix
* polish code
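The change lets the plugin's ZeRO optimizer compute a global gradient norm for clipping: gradients are sharded across ranks, so the squared norms of the local slices have to be accumulated across every parallel group that shards them before any rank scales its copy. Below is a minimal sketch of that idea in plain torch.distributed, not the plugin's actual implementation; the function name, the process_groups argument, and the epsilon are assumptions, and it glosses over details a real implementation must handle (e.g. parameters replicated across ranks must not be double-counted).

    import torch
    import torch.distributed as dist

    def clip_grad_norm_sketch(local_grads, max_norm, process_groups, eps=1e-6):
        # Squared 2-norm of the gradient slices held by this rank.
        total = sum(g.detach().float().norm(2) ** 2 for g in local_grads)
        # Accumulate over each group that shards gradients (e.g. the
        # ZeRO/data-parallel group and the tensor-parallel group), giving
        # the squared global norm of the full model's gradients.
        for pg in process_groups:
            dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pg)
        global_norm = total.sqrt().item()
        # Scale every local slice in place when the norm exceeds the cap.
        clip_coef = max_norm / (global_norm + eps)
        if clip_coef < 1.0:
            for g in local_grads:
                g.detach().mul_(clip_coef)
        return global_norm

To make this possible, the commit extends GradientStore (the optimizer's bookkeeping for gradient slices) with a reverse index from gradient tensors to parameter ids, shown in the diff below.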
@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
         # for zero2, it's `param_id: [grad_local_rank]`
         self._working_index = 0 if partition_grad else self._local_rank

+        self.grad_to_param_mapping = dict()
+
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         """Return list of gradient slices of a specific parameter

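For orientation, the store keeps gradients in a nested dict, _grads_of_params[group_id][param_id], holding a list of slices, and _working_index selects this rank's working slice: 0 under zero2 (partition_grad), otherwise the local rank. A minimal mock of that layout, with the values invented for illustration:

    import torch

    # Layout mock: {group_id: {param_id: [grad slices, one per rank]}}.
    grads_of_params = {0: {42: [torch.zeros(4), torch.ones(4)]}}
    working_index = 1  # e.g. local rank 1 when gradients are not partitioned
    working_grad = grads_of_params[0][42][working_index]  # this rank's slice

The grad_to_param_mapping dict initialized here is the reverse index; the next hunk fills it in.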
@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
         else:
             self._grads_of_params[group_id][param_id].append(grad)

+        self.grad_to_param_mapping[id(grad)] = param_id
+
     def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
         """Add a gradient slice on an existing slice of the parameter's gradient
         Used when no_sync is not activated.
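Note that the mapping is keyed on id(grad), the identity of the tensor object rather than its value, so a lookup only succeeds with the exact tensor that was registered. A tiny illustration (plain torch; the values are invented):

    import torch

    mapping = dict()
    grad = torch.zeros(4)                    # a stored gradient slice
    mapping[id(grad)] = 7                    # register: object identity -> param id
    assert mapping[id(grad)] == 7            # same object: lookup succeeds
    assert id(grad.clone()) not in mapping   # a copy has a different identity

Keying on identity is cheap and avoids hashing tensor contents, but it only works as long as the store keeps the original tensor objects alive.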
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):

         return grad_list

+    def get_working_grad_by_param_id(self, param_id) -> Tensor:
+        """
+        Return the working gradient for the specified parameter.
+
+        Args:
+            param_id (int): The index of the parameter.
+
+        Returns:
+            Tensor: The working gradient slice for the specified param_id.
+        """
+
+        for group in self._grads_of_params.values():
+            if param_id in group.keys():
+                return group[param_id][self._working_index]
+
+        raise KeyError(f"Working gradient for param_id {param_id} not found.")
+
     def reset_grads_by_group_id(self, group_id: int):
         self._grads_of_params[group_id] = dict()

     def reset_all_gradients(self):
         self._grads_of_params = dict()
+
+    def get_param_id_for_grad(self, grad: Tensor) -> int:
+        """Return the id of a parameter which the gradient slice belongs to
+
+        Args:
+            grad (Tensor): the gradient slice
+
+        Returns:
+            int: the id of a parameter which the gradient slice belongs to
+        """
+
+        return self.grad_to_param_mapping[id(grad)]
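Together these methods give the optimizer a reverse index over gradient slices, which is what a higher-level clip_grad_norm pass needs to walk from a working gradient back to its parameter. A hedged usage sketch, where store (a GradientStore) and grad (one of the slices it holds) are assumed to exist:

    param_id = store.get_param_id_for_grad(grad)             # reverse lookup via id(grad)
    working = store.get_working_grad_by_param_id(param_id)   # slice at _working_index
    store.reset_grads_by_group_id(group_id=0)                # drop one group's gradients
    store.reset_all_gradients()                              # drop every group's gradients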