mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 05:29:36 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones_like, don't generate a chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoparallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full evoformer tracer, which is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently, ColoGraphModule sets ActivationCheckpointCodeGen automatically in __init__, which prevents any other codegen from being set.
So I added an argument to control whether ActivationCheckpointCodeGen is set in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db
.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
from functools import partial
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.fx
|
|
import torch.multiprocessing as mp
|
|
|
|
# fastfold provides the ExtraMSABlock under test; it is optional, and when it
# cannot be imported HAS_REPO is False so the test below is skipped.
try:
    from fastfold.model.nn.evoformer import ExtraMSABlock

    HAS_REPO = True
except Exception:
    # Use ``except Exception`` instead of a bare ``except`` so that
    # KeyboardInterrupt / SystemExit are not swallowed; any ordinary error
    # raised while importing fastfold still just disables the test.
    HAS_REPO = False
|
|
|
|
import colossalai
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.fx._compatibility import is_compatible_with_meta
|
|
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
|
from colossalai.fx.graph_module import ColoGraphModule
|
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
|
from colossalai.utils import free_port
|
|
|
|
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
|
from colossalai.fx.profiler import MetaTensor
|
|
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
|
|
|
|
|
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    """Check that the autochunk-compiled module reproduces the eager model.

    Both modules are run under ``torch.no_grad()`` on the same inputs, and
    outputs 0 and 1 of each are compared with ``torch.allclose(atol=1e-4)``.

    Args:
        model: the original (eager) module; moved to CUDA before running.
        gm: the traced and recompiled module carrying the autochunk codegen.
        node, pair, node_mask, pair_mask: positional inputs forwarded
            unchanged to both modules.

    Raises:
        AssertionError: if either compared output differs beyond tolerance.
    """
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    # Outputs 0 and 1 are compared (presumably the updated MSA and pair
    # representations -- confirm against ExtraMSABlock). Indexing explicitly
    # keeps the IndexError if an output has fewer than two entries.
    for idx in (0, 1):
        expected, actual = non_fx_out[idx], fx_out[idx]
        assert torch.allclose(expected, actual,
                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                                  torch.abs(expected - actual))
|
|
|
|
|
|
def _build_openfold():
    """Create a fastfold ExtraMSABlock in eval mode on the current CUDA device.

    All hyper-parameters are fixed. ``c_m=256`` and ``c_z=128`` agree with the
    channel sizes of the random inputs built in ``_test_extramsa_codegen``.
    """
    block_config = dict(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    )
    block = ExtraMSABlock(**block_config)
    # eval() disables dropout so eager and traced runs are comparable.
    block = block.eval()
    return block.cuda()
|
|
|
|
|
|
def _extramsa_trace_args(node, pair, node_mask, pair_mask):
    """Return fresh ``(meta_args, concrete_args)`` for tracing ExtraMSABlock.

    The inputs are keyed by the argument names the tracers expect and moved to
    the ``meta`` device; ``chunk_size`` and ``_chunk_logits`` are pinned as
    concrete values. New dicts and meta tensors are built on every call, so two
    traces never share mutable argument state.
    """
    meta = torch.device("meta")
    meta_args = {
        "m": node.to(meta),
        "z": pair.to(meta),
        "msa_mask": node_mask.to(meta),
        "pair_mask": pair_mask.to(meta),
    }
    concrete_args = {
        "chunk_size": None,
        "_chunk_logits": 1024,
    }
    return meta_args, concrete_args


def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
    """Trace ExtraMSABlock with AutoChunkCodeGen and compare it against eager.

    Args:
        rank: process rank (supplied by ``mp.spawn``; world size is 1).
        msa_len: size of dim 1 of the MSA input ``node``.
        pair_len: residue length used for dims 2 of ``node`` and 1-2 of ``pair``.
        max_memory: forwarded unchanged to ``AutoChunkCodeGen`` (the
            parametrized test also passes ``None``).

    Raises:
        AssertionError: if no chunk was inserted into the generated code, or if
            the chunked module's outputs diverge from the eager model.
    """
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    try:
        # build model and input
        model = _build_openfold()
        # NOTE(review): masks are drawn from randn, not {0, 1}; that is fine for
        # an equivalence check but not representative of real masks.
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        node_mask = torch.randn(1, msa_len, pair_len).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()
        pair_mask = torch.randn(1, pair_len, pair_len).cuda()

        # trace the meta graph and set up codegen
        meta_args, concrete_args = _extramsa_trace_args(node, pair, node_mask, pair_mask)
        meta_graph = symbolic_trace(model, meta_args=meta_args, concrete_args=concrete_args)
        interp = MetaInfoProp(meta_graph)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"),
            MetaTensor(pair, fake_device="cuda:0"),
            MetaTensor(node_mask, fake_device="cuda:0"),
            MetaTensor(pair_mask, fake_device="cuda:0"),
        )
        codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

        # trace and recompile
        # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
        meta_args, concrete_args = _extramsa_trace_args(node, pair, node_mask, pair_mask)
        graph = ColoTracer().trace(model, meta_args=meta_args, concrete_args=concrete_args)
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph, ckpt_codegen=False)
        gm.recompile()

        # assert we have inserted chunk
        code = graph.python_code("self").src
        assert "chunk_result = None; chunk_size = None;" in code

        _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    finally:
        # Tear down the global context even when an assertion above fails.
        gpc.destroy()
|
|
|
|
|
|
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0 or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
    """Run ``_test_extramsa_codegen`` in a single spawned worker process.

    Skipped when the autochunk codegen / meta support is unavailable or when
    fastfold could not be imported (``HAS_REPO`` is False). ``mp.spawn``
    passes the process index as the first positional (``rank``) argument;
    the remaining arguments are bound by keyword.
    """
    run_func = partial(
        _test_extramsa_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run one configuration directly in this process (rank 0, no memory
    # budget), bypassing pytest and mp.spawn.
    _test_extramsa_codegen(rank=0, msa_len=32, pair_len=64, max_memory=None)
|