Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
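The headline change for users is the new static placement policy in the Gemini plugin. Below is a minimal sketch of how it might be enabled through the booster API, assuming the plugin exposes `placement_policy` plus offload-fraction arguments as described in this PR; the offload argument names are assumptions, so check the `GeminiPlugin` docstring of your installed version.

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Launch with torchrun; the config dict was still required at this point in time.
colossalai.launch_from_torch(config={})

# "static" pins parameters/optimizer states to a fixed device split chosen up
# front, instead of the heuristic "auto" policy that migrates chunks at runtime.
plugin = GeminiPlugin(
    placement_policy="static",   # policy added by this PR
    offload_optim_frac=0.0,      # assumed name: fraction of optimizer states kept on CPU
    offload_param_frac=0.0,      # assumed name: fraction of parameters kept on CPU
)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```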
@@ -3,9 +3,15 @@ from typing import Optional
 import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.const import TensorType
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+
+from .colo_tensor import _convert_output
+
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+
+
+def is_no_hook_op(func) -> bool:
+    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS
 
 
 def filter_colo_parameters(*args, **kwargs):
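For context on the hunk above: the new `WHITE_LIST_FUNCS` / `is_no_hook_op` pair decides which operations bypass the parameter op hooks. Dunder calls such as `__repr__` are skipped, while `__getitem__` stays whitelisted so that indexing a parameter still goes through the hooks. A small standalone illustration of the predicate, mirroring the added lines with a few asserts of my own:

```python
import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}


def is_no_hook_op(func) -> bool:
    # Dunder ops bypass the param op hooks, unless explicitly whitelisted.
    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS


assert is_no_hook_op(torch.Tensor.__repr__)           # dunder -> hooks skipped
assert not is_no_hook_op(torch.Tensor.__getitem__)    # whitelisted -> still hooked
assert not is_no_hook_op(torch.nn.functional.linear)  # ordinary op -> hooked
```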
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     """
 
-    def __new__(cls,
-                data: Optional[torch.Tensor] = None,
-                requires_grad: bool = True,
-                spec: ColoTensorSpec = None) -> 'ColoParameter':
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
-    def __init__(self,
-                 data: Optional[torch.Tensor] = None,
-                 requires_grad: bool = True,
-                 spec: ColoTensorSpec = None) -> None:
-        ColoTensor.__init__(self, data, spec)
-        self._type = TensorType.MODEL
-        # a list contains modules sharing this ColoParameter with others.
-        self._shared_param_modules = []
-
-    @property
-    def shared_param_modules(self):
-        return self._shared_param_modules
-
-    @staticmethod
-    def from_torch_tensor(tensor: torch.Tensor,
-                          requires_grad: bool = True,
-                          spec: ColoTensorSpec = None) -> 'ColoParameter':
-        tensor = tensor.as_subclass(ColoParameter)
-        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
-        return tensor
-
-    def __repr__(self):
-        return super(ColoParameter, self).__repr__()
-
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ColoParamOpHookManager.has_hook():
-            if not func.__name__.startswith('__'):
-                if kwargs is None:
-                    kwargs = {}
-                params = filter_colo_parameters(*args, **kwargs)
-                if len(params) > 0:
-                    with torch._C.DisableTorchFunction():
-                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
-                    args, kwargs = replace_args(args, kwargs, new_args)
-                    ret = super().__torch_function__(func, types, args, kwargs)
-                    with torch._C.DisableTorchFunction():
-                        ret = ColoParamOpHookManager.post_op(params, ret)
-                    return ret
+        if kwargs is None:
+            kwargs = {}
+        if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
+            params = filter_colo_parameters(*args, **kwargs)
+            if len(params) > 0:
+                with torch._C.DisableTorchFunction():
+                    new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                args, kwargs = replace_args(args, kwargs, new_args)
+                ret = super().__torch_function__(func, types, args, kwargs)
+                with torch._C.DisableTorchFunction():
+                    ret = ColoParamOpHookManager.post_op(params, ret)
+                return _convert_output(ret, func)
        return super().__torch_function__(func, types, args, kwargs)
 
     def __deepcopy__(self, memo):
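The rewritten `__torch_function__` above hoists the `kwargs` default, replaces the nested dunder check with `is_no_hook_op`, and routes hooked results through `_convert_output`. The following is a rough, self-contained sketch of that dispatch pattern using stand-in classes (not the real `ColoParamOpHookManager` or `ColoParameter`), just to show which calls the whitelist lets through:

```python
from typing import List

import torch


class DemoHookManager:
    """Stand-in for ColoParamOpHookManager: just records pre/post events."""
    trace: List[str] = []

    @classmethod
    def pre_op(cls, name: str) -> None:
        cls.trace.append(f"pre_op:{name}")

    @classmethod
    def post_op(cls, name: str) -> None:
        cls.trace.append(f"post_op:{name}")


class DemoParam(torch.Tensor):
    """Tensor subclass mimicking ColoParameter's hook dispatch."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Same predicate as is_no_hook_op: skip dunders except __getitem__.
        skip = func.__name__.startswith('__') and func is not torch.Tensor.__getitem__
        if not skip:
            DemoHookManager.pre_op(func.__name__)
        ret = super().__torch_function__(func, types, args, kwargs)
        if not skip:
            DemoHookManager.post_op(func.__name__)
        return ret


p = torch.randn(4).as_subclass(DemoParam)
torch.add(p, 1)  # ordinary op: wrapped by pre_op/post_op
p[0]             # __getitem__: whitelisted, still wrapped
repr(p)          # __repr__: dunder, bypasses the hooks
print(DemoHookManager.trace)
# expected: pre/post entries for 'add' and '__getitem__', nothing for '__repr__'
```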
@@ -96,9 +74,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data,
-                                   self.requires_grad,
-                                   spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
+            tensor = ColoParameter(data, self.requires_grad)
             memo[id(self)] = tensor
             return tensor
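Since `__deepcopy__` no longer rebuilds a `ColoTensorSpec`, copying a parameter needs nothing but the data and the `requires_grad` flag. A small usage sketch consistent with the constructor and `__deepcopy__` shown in this diff (not taken from the repository's tests):

```python
import copy

import torch

from colossalai.tensor.colo_parameter import ColoParameter

# Construction now takes only data and requires_grad; no ColoTensorSpec,
# process group or compute spec is involved any more.
p = ColoParameter(torch.randn(4, 4), requires_grad=True)
q = copy.deepcopy(p)

assert isinstance(q, ColoParameter)
assert q.requires_grad == p.requires_grad
assert torch.equal(q, p)  # same values, independent storage
```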