mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-23 20:43:19 +00:00
[elixir] add elixir plugin and its unit test (#3865)
This commit is contained in:
@@ -12,6 +12,17 @@ from .functions import postfwd_prebwd_function, prefwd_postbwd_function
|
||||
from .storage import BufferStore
|
||||
|
||||
|
||||
def always_skip(func, args, kwargs) -> bool:
|
||||
if is_no_hook_op(func):
|
||||
return True
|
||||
if func is torch.Tensor.reshape_as:
|
||||
if isinstance(args[0], HookParam):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class HookParam(OutplaceTensor, nn.Parameter):
|
||||
"""HookParam is a special type of tensor that is used to triggered hooks on parameters.
|
||||
HookParam adds chunk fetching before torch functions.
|
||||
@@ -43,7 +54,7 @@ class HookParam(OutplaceTensor, nn.Parameter):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if is_no_hook_op(func):
|
||||
if always_skip(func, args, kwargs):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user