Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
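The headline feature is the static placement policy for Gemini. Below is a minimal sketch of opting into it through the booster plugin; placement_policy="static" is what this PR adds, but treat the surrounding setup (launcher call, optimizer choice) as assumptions rather than a verbatim recipe.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

# "static" pins parameter placement up front instead of migrating chunks
# dynamically as the "auto" policy does.
plugin = GeminiPlugin(placement_policy="static")
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)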
@@ -3,9 +3,7 @@ from contextlib import contextmanager
 from typing import Any, List, Tuple
 
 import torch
-
-from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
 
 
 class ColoParamOpHook(ABC):
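The swapped import does real work: ColoTensor-specific spec bookkeeping gives way to PyTorch's generic pytree utilities. A standalone sketch of the flatten/unflatten round-trip they provide (plain torch.utils._pytree behavior, not code from this patch):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Arbitrarily nested containers flatten into a list of leaves plus a TreeSpec
# that remembers the structure.
args = ({"x": torch.zeros(2), "y": [1, torch.ones(3)]}, "label")
leaves, spec = tree_flatten(args)   # leaves: [tensor, 1, tensor, 'label']

# tree_unflatten rebuilds the original nesting from the leaves and the spec.
rebuilt = tree_unflatten(leaves, spec)
assert rebuilt[1] == "label"
assert rebuilt[0]["y"][0] == 1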
@@ -82,26 +80,18 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        grad_args, rear_args = _get_grad_args(*args)
-        colo_info = _get_colo_tensors_info(*grad_args)
-        rets = PreFwdPostBwd.apply(params, *grad_args)
-        update_args = _update_colo_tensors(colo_info, *rets)
-        if rear_args is None:
-            return update_args
-        else:
-            arg_zero = (tuple(update_args),)
-            return arg_zero + rear_args
+        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
+        # if one of the input requires grad, all the output will be treated as requires grad
+        # and will have grad fn even the corresponding input does not require grad
+        # we have to extract tensors requiring grad into flat list and then merge them back
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
+        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)
 
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        colo_info = _get_colo_tensors_info(arg)
-        ret = PostFwdPreBwd.apply(params, arg)
-        res = _update_colo_tensors(colo_info, ret)
-        if len(res) == 1:
-            return res[0]
-        else:
-            return res
+        return PostFwdPreBwd.apply(params, arg)
 
     @staticmethod
     def has_hook() -> bool:
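The new comment block in pre_op captures the crux of the refactor: torch.autograd.Function only recognizes positional torch.Tensor arguments, and once any input requires grad, every tensor output is attached to the autograd graph. A self-contained demonstration of that behavior (an illustrative identity function, not code from this patch):

import torch

class PassThrough(torch.autograd.Function):
    # Identity function, shaped like PreFwdPostBwd: tensors in, tensors out.

    @staticmethod
    def forward(ctx, *tensors):
        return tensors

    @staticmethod
    def backward(ctx, *grads):
        return grads

a = torch.randn(2, requires_grad=True)
b = torch.randn(2)                       # does NOT require grad
out_a, out_b = PassThrough.apply(a, b)

# Both outputs now carry a grad_fn, even though only `a` required grad; this
# is why the hook manager partitions grad tensors from everything else before
# calling .apply, then merges the results back.
print(out_a.grad_fn, out_b.grad_fn)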
@@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
     return False
 
 
-def _has_grad_tensor(obj) -> bool:
-    if isinstance(obj, tuple) or isinstance(obj, list):
-        for x in obj:
-            if _has_grad_tensor(x):
-                return True
-        return False
-    elif isinstance(obj, dict):
-        for x in obj.values():
-            if _has_grad_tensor(x):
-                return True
-        return False
-    else:
-        return _is_grad_tensor(obj)
-
-
-def _get_grad_args(*args):
-    # if there is no grad tensors, do nothing
-    if not _has_grad_tensor(args):
-        return args, None
-    # returns the identical args if there is a grad tensor
-    for obj in args:
-        if _is_grad_tensor(obj):
-            return args, None
-    # otherwise, the first argument should be a tuple of grad tensors
-    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
-    arg_zero = args[0]
-    if not isinstance(arg_zero, tuple):
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    check_grad_flag = False
-    for obj in arg_zero:
-        check_grad_flag |= _is_grad_tensor(obj)
-    if not check_grad_flag:
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    return arg_zero, args[1:]
-
-
-def _get_colo_tensors_info(*args) -> list:
-    info = []
-    for arg in args:
-        if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
-        else:
-            info.append(None)
-    return info
-
-
-def _update_colo_tensors(info, *args) -> list:
-    ret = []
-    for t_info, arg in zip(info, args):
-        if t_info is not None:
-            t_cls, spec = t_info
-            arg = t_cls.from_torch_tensor(arg, spec=spec)
-        ret.append(arg)
-    return ret
+def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
+    flat_args, spec = tree_flatten(args)
+    grad_args = []
+    other_args = []
+    grad_flags = []
+    for arg in flat_args:
+        flag = _is_grad_tensor(arg)
+        grad_flags.append(flag)
+        if flag:
+            grad_args.append(arg)
+        else:
+            other_args.append(arg)
+    assert len(grad_args) > 0
+    return grad_args, other_args, grad_flags, spec
+
+
+def _merge_args(grad_args, other_args, grad_flags, spec):
+    grad_iter = iter(grad_args)
+    other_iter = iter(other_args)
+    flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
+    return tree_unflatten(flat_args, spec)
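Together, _flatten_grad_args and _merge_args form an order-preserving partition and merge: grad_flags records, leaf by leaf, which bucket each value landed in, so the two streams zip back together deterministically. A quick self-contained check of that invariant, using hypothetical helper names that mirror the patch:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def split_by(pred, args):
    # Partition pytree leaves by a predicate, keeping a per-leaf flag list.
    leaves, spec = tree_flatten(args)
    flags = [pred(x) for x in leaves]
    hits = [x for x, f in zip(leaves, flags) if f]
    rest = [x for x, f in zip(leaves, flags) if not f]
    return hits, rest, flags, spec

def merge_by(hits, rest, flags, spec):
    # Replay the flags to interleave the two buckets in their original order.
    hit_iter, rest_iter = iter(hits), iter(rest)
    return tree_unflatten([next(hit_iter) if f else next(rest_iter) for f in flags], spec)

t = torch.randn(2, requires_grad=True)
args = (t, 3, {"k": torch.zeros(1)})
hits, rest, flags, spec = split_by(lambda x: isinstance(x, torch.Tensor) and x.requires_grad, args)
rebuilt = merge_by(hits, rest, flags, spec)
assert rebuilt[0] is t and rebuilt[1] == 3   # structure and order round-trip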