[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
This commit is contained in:
Hongxin Liu
2023-08-24 09:29:25 +08:00
committed by GitHub
parent 285fe7ba71
commit 27061426f7
82 changed files with 1008 additions and 4036 deletions

View File

@@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given ZeroDDP module.
You should notice that the original ZeroDDP model is not modified.
"""Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original GeminiDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors.
Args:
zero_ddp_model (ZeroDDP): a zero ddp model
zero_ddp_model (GeminiDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the converted torch model
@@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
assert isinstance(zero_ddp_model, GeminiDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = zero_ddp_model.module