mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api
This commit is contained in:
33
colossalai/zero/low_level/zero_hook.py
Normal file
33
colossalai/zero/low_level/zero_hook.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import List
|
||||
|
||||
from torch._tensor import Tensor
|
||||
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
|
||||
# Name of the attribute under which a pending all-gather handle is attached to
# a parameter object by the helpers in this module.
_ALL_GATHER_HANDLE = "_all_gather_handle"
|
||||
|
||||
|
||||
def wait_all_gather_handle(p):
    """Block on the all-gather handle attached to ``p``, then detach it.

    If ``p`` has no attribute named by ``_ALL_GATHER_HANDLE`` this is a no-op.
    Otherwise the stored object's ``wait()`` method is called and the
    attribute is removed from ``p``, so a later call does nothing.

    Args:
        p: Object that may carry a handle attribute (presumably a model
            parameter tensor -- confirm against callers).
    """
    if not hasattr(p, _ALL_GATHER_HANDLE):
        return
    # NOTE(review): the handle is presumably an async communication work
    # object; only its ``wait()`` method is relied on here.
    getattr(p, _ALL_GATHER_HANDLE).wait()
    # Detach only after the wait has returned.
    delattr(p, _ALL_GATHER_HANDLE)
|
||||
|
||||
|
||||
def set_all_gather_handle(p, handle):
    """Attach ``handle`` to ``p`` under the ``_ALL_GATHER_HANDLE`` attribute.

    Any value previously stored under that attribute is overwritten.

    Args:
        p: Object to attach the handle to.
        handle: Object to store; ``wait_all_gather_handle`` later calls its
            ``wait()`` method.
    """
    attr_name = _ALL_GATHER_HANDLE
    setattr(p, attr_name, handle)
|
||||
|
||||
|
||||
class ZeroOpHook(ColoParamOpHook):
    """Parameter-op hook that waits on pending all-gather handles.

    Only ``pre_forward`` does any work: it calls ``wait_all_gather_handle``
    on every parameter it is given. The remaining hook points are
    intentionally empty.
    """

    def pre_forward(self, params: List[Tensor]) -> None:
        """Wait on (and detach) the all-gather handle of each parameter."""
        for param in params:
            wait_all_gather_handle(param)

    def post_forward(self, params: List[Tensor]) -> None:
        """Do nothing; no action is taken at this hook point."""

    def pre_backward(self, params: List[Tensor]) -> None:
        """Do nothing; no action is taken at this hook point."""

    def post_backward(self, params: List[Tensor]) -> None:
        """Do nothing; no action is taken at this hook point."""
|
Reference in New Issue
Block a user