[fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility
Authored by Hongxin Liu on 2024-08-07 18:21:08 +08:00 (committed via GitHub)
parent 76ea16466f
commit ccabcf6485
6 changed files with 102 additions and 3 deletions
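
The hunks below show the op-rewrite hook infrastructure this feature is built on. For orientation, here is a minimal usage sketch of enabling fp8 AMP at the plugin level. It is a hedged illustration only: the use_fp8 flag, the launch call, and the surrounding boilerplate are assumptions not shown in this excerpt, and the script needs a torchrun-style distributed launch.

# Hedged sketch: `use_fp8=True` is an assumed HybridParallelPlugin flag, not shown in this diff.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # requires a torchrun launch; signature may vary by version

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    precision="bf16",
    use_fp8=True,  # assumed flag that turns on the fp8 op-rewrite hook for eligible linears
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)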


@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                 args, kwargs = replace_args(args, kwargs, new_args)
+                with torch._C.DisableTorchFunction():
+                    func = ColoParamOpHookManager.rewrite_op(func)
                 ret = super().__torch_function__(func, types, args, kwargs)
                 with torch._C.DisableTorchFunction():
                     ret = ColoParamOpHookManager.post_op(params, ret)
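
With this change, every op intercepted by ColoParameter.__torch_function__ is first offered to the registered hooks, which may substitute a different callable before dispatch. That is the mechanism the fp8 hook added in this commit relies on to reroute eligible linear calls to fp8 kernels. The sketch below only illustrates the hook contract: the class name, the linear_fp8 stand-in, and the import path are assumptions, not the code shipped in this PR.

from typing import Any, List

import torch
import torch.nn.functional as F

# Import path assumed from the class names in the hunks below.
from colossalai.tensor.param_op_hook import ColoParamOpHook


def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
    # Stand-in only: a real fp8 linear would cast to float8 and call a fused GEMM.
    return F.linear(input, weight, bias)


class ToyFP8Hook(ColoParamOpHook):
    """Illustrative hook: reroutes F.linear through linear_fp8 via the new rewrite_op."""

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass

    def rewrite_op(self, func) -> Any:
        # Swap the op before ColoParameter dispatches it; leave everything else untouched.
        return linear_fp8 if func is F.linear else func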


@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
     def post_backward(self, params: List[torch.Tensor]) -> None:
         pass
 
+    def rewrite_op(self, func) -> Any:
+        return func
+
 
 class ColoParamOpHookManager:
     """
@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
     def has_hook() -> bool:
         return len(ColoParamOpHookManager.hooks) > 0
 
+    @staticmethod
+    def rewrite_op(func) -> Any:
+        for hook in ColoParamOpHookManager.hooks:
+            func = hook.rewrite_op(func)
+        return func
+
 
 class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
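
On the manager side, rewrite_op threads the candidate op through every registered hook in order, so a hook like the one sketched above is consulted for each intercepted call. A small check of that dispatch path, reusing ToyFP8Hook and linear_fp8 from the earlier sketch and assuming the manager's usual use_hooks context manager for registration (only has_hook, rewrite_op, and hooks appear in this excerpt):

import torch.nn.functional as F

from colossalai.tensor.param_op_hook import ColoParamOpHookManager

# `use_hooks` is assumed to be the manager's registration context manager;
# it is not shown in this diff.
with ColoParamOpHookManager.use_hooks(ToyFP8Hook()):
    assert ColoParamOpHookManager.has_hook()
    rewritten = ColoParamOpHookManager.rewrite_op(F.linear)
    assert rewritten is linear_fp8  # the toy hook rerouted the op

# Outside the context no hooks are registered, so ops pass through unchanged.
assert ColoParamOpHookManager.rewrite_op(F.linear) is F.linear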