mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,7 +7,7 @@ INPALCE_MAPPING = {
|
||||
torch.Tensor.add_: torch.Tensor.add,
|
||||
torch.Tensor.sub_: torch.Tensor.sub,
|
||||
torch.Tensor.mul_: torch.Tensor.mul,
|
||||
torch.Tensor.div_: torch.Tensor.div
|
||||
torch.Tensor.div_: torch.Tensor.div,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
|
||||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
}
|
||||
|
||||
|
||||
@@ -37,17 +37,18 @@ def _convert_output(output, func):
|
||||
|
||||
|
||||
class ColoTensor(torch.Tensor):
|
||||
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
|
||||
"""Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
|
||||
|
||||
It is only used to trigger the torch function hook.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
"""
|
||||
torch_major = int(torch.__version__.split('.')[0])
|
||||
torch_minor = int(torch.__version__.split('.')[1])
|
||||
|
||||
def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
|
||||
torch_major = int(torch.__version__.split(".")[0])
|
||||
torch_minor = int(torch.__version__.split(".")[1])
|
||||
|
||||
def __new__(cls, data: torch.Tensor) -> "ColoTensor":
|
||||
"""
|
||||
The signature of the __new__ has to be consistent with the torch.Tensor.
|
||||
|
||||
@@ -74,7 +75,7 @@ class ColoTensor(torch.Tensor):
|
||||
# we have to capture the `backward` function
|
||||
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
|
||||
if func is torch.Tensor.backward:
|
||||
assert len(args) == 1 # only has 1 parameter
|
||||
assert len(args) == 1 # only has 1 parameter
|
||||
backward_tensor = torch.Tensor(args[0])
|
||||
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
|
||||
return backward_tensor.backward(**tensor_kwargs)
|
||||
@@ -83,8 +84,8 @@ class ColoTensor(torch.Tensor):
|
||||
if func in INPALCE_MAPPING:
|
||||
func = INPALCE_MAPPING[func]
|
||||
# set the 'inplace' kwargs to False
|
||||
if 'inplace' in kwargs:
|
||||
kwargs['inplace'] = False
|
||||
if "inplace" in kwargs:
|
||||
kwargs["inplace"] = False
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user