[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
This commit is contained in:
Hongxin Liu
2023-08-24 09:29:25 +08:00
committed by GitHub
parent 285fe7ba71
commit 27061426f7
82 changed files with 1008 additions and 4036 deletions

View File

@@ -3,9 +3,7 @@ from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.tensor_spec import ColoTensorSpec
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
class ColoParamOpHook(ABC):
@@ -82,26 +80,18 @@ class ColoParamOpHookManager:
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
ColoParamOpHookManager._trigger_pre_forward(params)
grad_args, rear_args = _get_grad_args(*args)
colo_info = _get_colo_tensors_info(*grad_args)
rets = PreFwdPostBwd.apply(params, *grad_args)
update_args = _update_colo_tensors(colo_info, *rets)
if rear_args is None:
return update_args
else:
arg_zero = (tuple(update_args),)
return arg_zero + rear_args
# auto grad function can only recognize torch.Tensor, thus we have to flatten the input
# if one of the input requires grad, all the output will be treated as requires grad
# and will have grad fn even the corresponding input does not require grad
# we have to extract tensors requiring grad into flat list and then merge them back
grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
return _merge_args(new_grad_args, other_args, grad_flags, spec)
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
ColoParamOpHookManager._trigger_post_forward(params)
colo_info = _get_colo_tensors_info(arg)
ret = PostFwdPreBwd.apply(params, arg)
res = _update_colo_tensors(colo_info, ret)
if len(res) == 1:
return res[0]
else:
return res
return PostFwdPreBwd.apply(params, arg)
@staticmethod
def has_hook() -> bool:
@@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
return False
def _has_grad_tensor(obj) -> bool:
if isinstance(obj, tuple) or isinstance(obj, list):
for x in obj:
if _has_grad_tensor(x):
return True
return False
elif isinstance(obj, dict):
for x in obj.values():
if _has_grad_tensor(x):
return True
return False
else:
return _is_grad_tensor(obj)
def _get_grad_args(*args):
# if there is no grad tensors, do nothing
if not _has_grad_tensor(args):
return args, None
# returns the identical args if there is a grad tensor
for obj in args:
if _is_grad_tensor(obj):
return args, None
# otherwise, the first argument should be a tuple of grad tensors
# if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
arg_zero = args[0]
if not isinstance(arg_zero, tuple):
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
check_grad_flag = False
for obj in arg_zero:
check_grad_flag |= _is_grad_tensor(obj)
if not check_grad_flag:
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
return arg_zero, args[1:]
def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
flat_args, spec = tree_flatten(args)
grad_args = []
other_args = []
grad_flags = []
for arg in flat_args:
flag = _is_grad_tensor(arg)
grad_flags.append(flag)
if flag:
grad_args.append(arg)
else:
info.append(None)
return info
other_args.append(arg)
assert len(grad_args) > 0
return grad_args, other_args, grad_flags, spec
def _update_colo_tensors(info, *args) -> list:
ret = []
for t_info, arg in zip(info, args):
if t_info is not None:
t_cls, spec = t_info
arg = t_cls.from_torch_tensor(arg, spec=spec)
ret.append(arg)
return ret
def _merge_args(grad_args, other_args, grad_flags, spec):
grad_iter = iter(grad_args)
other_iter = iter(other_args)
flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
return tree_unflatten(flat_args, spec)