ColossalAI/colossalai/tensor/colo_tensor.py

from functools import lru_cache
from typing import Callable, Set

import torch

INPALCE_MAPPING = {
    torch.Tensor.add_: torch.Tensor.add,
    torch.Tensor.sub_: torch.Tensor.sub,
    torch.Tensor.mul_: torch.Tensor.mul,
    torch.Tensor.div_: torch.Tensor.div
}


@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
        Tensor.data.__get__,    # make .data return a torch.Tensor rather than a ColoTensor
    }


def _convert(output):
    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
        output.__class__ = ColoTensor
    elif isinstance(output, (list, tuple)):
        output = type(output)(_convert(o) for o in output)
    return output


def _convert_output(output, func):
    if func in _get_my_nowrap_functions():
        return output
    return _convert(output)
class ColoTensor(torch.Tensor):
    """Data structure for tensors in Colossal-AI. It is a subclass of torch.Tensor.

    It is only used to trigger the torch function hook.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
    """

    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])

    def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
        """
        The signature of __new__ has to be consistent with that of torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
            # In order to trigger the pre-op hook in the forward of a checkpointed module,
            # we have to capture the `backward` function
            # and make sure that it is not called inside the `torch._C.DisableTorchFunction()` context.
            if func is torch.Tensor.backward:
                assert len(args) == 1    # backward takes only the tensor itself as a positional argument
                backward_tensor = torch.Tensor(args[0])
                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                return backward_tensor.backward(**tensor_kwargs)

        # replace an in-place function with its out-of-place counterpart
        if func in INPALCE_MAPPING:
            func = INPALCE_MAPPING[func]
        # set the 'inplace' kwarg to False
        if 'inplace' in kwargs:
            kwargs['inplace'] = False

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            return _convert_output(ret, func)
    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoTensor(data)
            memo[id(self)] = tensor
            return tensor
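

# --- Usage sketch (illustrative addition, not part of the original file) -----
# A minimal, hedged example of how ColoTensor behaves, based only on the code
# above. The tensor shapes and variable names are made up for demonstration;
# the guard below keeps the sketch from running on import.
if __name__ == "__main__":
    x = ColoTensor(torch.ones(2, 2))

    # Ordinary ops go through __torch_function__; their torch.Tensor outputs
    # are re-wrapped as ColoTensor by _convert_output.
    y = x * 2
    assert isinstance(y, ColoTensor)

    # In-place ops listed in INPALCE_MAPPING are redirected to their
    # out-of-place counterparts, so x.add_(1) runs as x.add(1), returning a
    # new ColoTensor and leaving x unchanged.
    z = x.add_(1)
    assert isinstance(z, ColoTensor)
    assert torch.equal(x, torch.ones(2, 2))

    # Tensor.data.__get__ is in _get_my_nowrap_functions(), so the result of
    # x.data is deliberately not re-wrapped into a ColoTensor.
    plain = x.data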