Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import
colossalai/zero/low_level/bookkeeping/gradient_store.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):

    def __init__(self, *args):
        super().__init__(*args)
        # bookkeeping data structures
        self._averaged_gradients = dict()

        # for backward reduction hooks
        self._grad_acc_objs = []

    def append_accumulate_grad_object(self, obj):
        """
        Keep a reference to an :class:`AccumulateGrad` object. If these objects are not
        kept alive, the reduction hooks attached to them may not fire.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group, creating an empty list
        for the group on first access.

        :param group_id: The index of the parameter group
        :type group_id: int

        :return: The list of averaged gradients of the parameter group. Each element
            is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        if group_id not in self._averaged_gradients:
            self._averaged_gradients[group_id] = []

        return self._averaged_gradients[group_id]

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an averaged gradient to the list of averaged gradients of a parameter group.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        if group_id in self._averaged_gradients:
            self._averaged_gradients[group_id].append(tensor)
        else:
            self._averaged_gradients[group_id] = [tensor]

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        """
        Accumulate an averaged gradient into an existing entry of a parameter group,
        i.e. add ``tensor`` in place to the gradient stored at ``tensor_idx``.

        :param group_id: The index of a parameter group
        :param tensor_idx: The index of a tensor in the list of averaged gradients
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor_idx: int
        :type tensor: torch.Tensor
        """
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the list of averaged gradients of a parameter group to an empty list.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        """
        Reset the whole bookkeeping structure for averaged gradients to an empty dict.
        """
        self._averaged_gradients = dict()