[elixir] add elixir plugin and its unit test (#3865)

This commit is contained in:
Haichen Huang
2023-05-31 12:10:44 +08:00
committed by GitHub
parent 206280408a
commit dbb9659099
10 changed files with 386 additions and 96 deletions

View File

@@ -12,6 +12,17 @@ from .functions import postfwd_prebwd_function, prefwd_postbwd_function
from .storage import BufferStore
def always_skip(func, args, kwargs) -> bool:
if is_no_hook_op(func):
return True
if func is torch.Tensor.reshape_as:
if isinstance(args[0], HookParam):
return False
else:
return True
return False
class HookParam(OutplaceTensor, nn.Parameter):
"""HookParam is a special type of tensor that is used to triggered hooks on parameters.
HookParam adds chunk fetching before torch functions.
@@ -43,7 +54,7 @@ class HookParam(OutplaceTensor, nn.Parameter):
if kwargs is None:
kwargs = {}
if is_no_hook_op(func):
if always_skip(func, args, kwargs):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return ret