[fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                 args, kwargs = replace_args(args, kwargs, new_args)
+                with torch._C.DisableTorchFunction():
+                    func = ColoParamOpHookManager.rewrite_op(func)
                 ret = super().__torch_function__(func, types, args, kwargs)
                 with torch._C.DisableTorchFunction():
                     ret = ColoParamOpHookManager.post_op(params, ret)
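The two added lines slot a rewrite_op step between argument pre-processing and the actual dispatch, so a hook can substitute the operator itself rather than only its inputs and outputs. A minimal, framework-free sketch of the resulting call order; every name below is an illustrative stand-in, not a ColossalAI API:

# Sketch of the dispatch order the hunk above establishes:
# pre_op -> rewrite_op -> op -> post_op.
from typing import Any, Callable, List

class SketchHook:
    def pre_op(self, params: List[Any], *args) -> List[Any]:
        return list(args)                 # hook sees the inputs first

    def rewrite_op(self, func: Callable) -> Callable:
        return func                       # hook may replace the op itself

    def post_op(self, params: List[Any], ret: Any) -> Any:
        return ret                        # hook sees the outputs last

def dispatch(hook: SketchHook, func: Callable, params: List[Any], *args) -> Any:
    args = hook.pre_op(params, *args)
    func = hook.rewrite_op(func)          # the step this hunk adds
    ret = func(*args)
    return hook.post_op(params, ret)

if __name__ == "__main__":
    print(dispatch(SketchHook(), lambda x, y: x + y, [], 1, 2))  # -> 3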
@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
     def post_backward(self, params: List[torch.Tensor]) -> None:
         pass
 
+    def rewrite_op(self, func) -> Any:
+        return func
+
 
 class ColoParamOpHookManager:
     """
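The base hook gains a rewrite_op callback that defaults to the identity function, so existing hooks keep working unchanged, while a precision-oriented hook can override it to swap an op for a low-precision kernel. A hypothetical subclass to illustrate the override point: SwapLinearHook and my_low_precision_linear are made up for this sketch, the import path and the remaining abstract callbacks (pre_forward/post_forward/pre_backward) are assumed from the surrounding module, and this is not the fp8 hook the commit actually ships.

from typing import Any, List

import torch
import torch.nn.functional as F

from colossalai.tensor.param_op_hook import ColoParamOpHook  # assumed import path


def my_low_precision_linear(input, weight, bias=None):
    # Placeholder for an alternative linear kernel (e.g. an FP8 GEMM).
    return F.linear(input, weight, bias)


class SwapLinearHook(ColoParamOpHook):
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass

    def rewrite_op(self, func) -> Any:
        # Intercept only F.linear; every other op passes through unchanged.
        return my_low_precision_linear if func is F.linear else func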
@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
     def has_hook() -> bool:
         return len(ColoParamOpHookManager.hooks) > 0
 
+    @staticmethod
+    def rewrite_op(func) -> Any:
+        for hook in ColoParamOpHookManager.hooks:
+            func = hook.rewrite_op(func)
+        return func
+
 
 class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
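Because the manager folds every registered hook's rewrite_op over the incoming op, rewrites compose in registration order. A standalone sketch of that folding behaviour; the hook classes here are illustrative and not part of ColossalAI:

from typing import Any, Callable, List

def compose_rewrites(hooks: List[Any], func: Callable) -> Callable:
    # Mirrors the loop in ColoParamOpHookManager.rewrite_op above.
    for hook in hooks:
        func = hook.rewrite_op(func)
    return func

class LogHook:
    def rewrite_op(self, func: Callable) -> Callable:
        def wrapped(*args, **kwargs):
            print(f"calling {getattr(func, '__name__', func)}")
            return func(*args, **kwargs)
        return wrapped

class DoubleHook:
    def rewrite_op(self, func: Callable) -> Callable:
        return lambda *args, **kwargs: 2 * func(*args, **kwargs)

def add(x, y):
    return x + y

if __name__ == "__main__":
    rewritten = compose_rewrites([LogHook(), DoubleHook()], add)
    print(rewritten(1, 2))  # logs "calling add", then prints 6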