[zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements
This commit is contained in:
Hongxin Liu
2024-04-16 17:49:21 +08:00
committed by GitHub
parent 89049b0d89
commit 3788fefc7a
3 changed files with 56 additions and 15 deletions

View File

@@ -11,7 +11,9 @@ from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self.reset_all()
def reset_all(self) -> None:
# init
self.current_group_id = 0
self._num_elements_in_bucket = 0