From 7a58dc5ad244f68d162e34a269a4a7c96f9896ab Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Fri, 27 Jan 2023 09:52:21 +0800
Subject: [PATCH] Update metainfo patch branch (#2517)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunk size in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fails
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restructure dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full evoformer tracer, which is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be used.
This change adds an argument to control whether ActivationCheckpointCodeGen is set in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db01b501e95b66e91be9fe27b58d2e58.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl
Co-authored-by: eric8607242
Co-authored-by: HELSON
Co-authored-by: Frank Lee
Co-authored-by: Haofan Wang
Co-authored-by: Jiarui Fang
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217
Co-authored-by: Ziyue Jiang
Co-authored-by: Ziyue Jiang
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス
---
.bdist.json | 24 +
.compatibility | 3 +
.github/workflows/README.md | 149 ++
.github/workflows/auto_compatibility_test.yml | 74 +
.github/workflows/auto_example_check.yml | 143 ++
.github/workflows/auto_release_bdist.yml | 70 +
.github/workflows/build.yml | 29 +-
...rigger_examples_check_and_weekly_check.yml | 119 --
...st.yml => dispatch_compatibility_test.yml} | 2 +-
...example.yml => dispatch_example_check.yml} | 44 +-
.../workflows/draft_github_release_post.yml | 3 +-
.github/workflows/pre_commit.yml | 71 +
.github/workflows/release_docker.yml | 29 +-
.github/workflows/release_nightly.yml | 86 +-
.../workflows/report_precommit_failure.yml | 67 +
.github/workflows/report_test_coverage.yml | 74 +
.../example_checks/check_dispatch_inputs.py | 27 +
.../check_example_weekly.py} | 9 +-
.../detect_changed_example.py} | 11 +-
.../workflows/scripts/input_check_example.py | 23 -
.github/workflows/translate_comment.yml | 18 +
.gitignore | 4 +
README-zh-Hans.md | 55 +-
README.md | 39 +-
.../passes/runtime_apply_pass.py | 33 +
.../passes/runtime_preparation_pass.py | 2 +
.../auto_parallel/tensor_shard/initialize.py | 72 +-
.../tensor_shard/node_handler/__init__.py | 3 +-
.../binary_elementwise_handler.py | 27 +-
.../tensor_shard/node_handler/node_handler.py | 18 +
.../tensor_shard/node_handler/option.py | 17 +
colossalai/autochunk/autochunk_codegen.py | 523 ++++++
colossalai/autochunk/estimate_memory.py | 323 ++++
colossalai/autochunk/reorder_graph.py | 117 ++
colossalai/autochunk/search_chunk.py | 319 ++++
colossalai/autochunk/select_chunk.py | 224 +++
colossalai/autochunk/trace_flow.py | 445 +++++
colossalai/autochunk/trace_indice.py | 703 ++++++++
colossalai/autochunk/utils.py | 132 ++
colossalai/cli/launcher/hostinfo.py | 9 +-
colossalai/cli/launcher/multinode_runner.py | 12 +-
colossalai/cli/launcher/run.py | 54 +-
colossalai/device/alpha_beta_profiler.py | 4 +-
colossalai/device/device_mesh.py | 30 +-
colossalai/fx/graph_module.py | 22 +-
.../fx/passes/adding_split_node_pass.py | 36 +
colossalai/fx/passes/meta_info_prop.py | 3 +-
colossalai/fx/profiler/opcount.py | 5 +-
colossalai/fx/profiler/tensor.py | 45 +-
colossalai/fx/tracer/_symbolic_trace.py | 3 +-
colossalai/fx/tracer/experimental.py | 42 +-
colossalai/gemini/chunk/search_utils.py | 9 +-
colossalai/gemini/chunk/utils.py | 15 +-
colossalai/gemini/gemini_mgr.py | 18 +-
colossalai/initialize.py | 29 +-
.../kernel/cuda_native/scaled_softmax.py | 17 +-
colossalai/nn/optimizer/cpu_adam.py | 2 +-
colossalai/nn/optimizer/fused_adam.py | 3 +-
colossalai/nn/optimizer/fused_lamb.py | 3 +-
colossalai/nn/optimizer/fused_sgd.py | 3 +-
colossalai/nn/optimizer/hybrid_adam.py | 2 +-
colossalai/nn/optimizer/zero_optimizer.py | 22 +-
colossalai/nn/parallel/data_parallel.py | 104 +-
colossalai/nn/parallel/gemini_parallel.py | 3 +-
colossalai/nn/parallel/utils.py | 25 +-
colossalai/pipeline/rpc/_pipeline_base.py | 4 +-
colossalai/pipeline/rpc/_pipeline_schedule.py | 3 -
colossalai/utils/__init__.py | 42 +-
colossalai/utils/common.py | 8 +-
colossalai/utils/model/experimental.py | 440 +++++
colossalai/zero/sharded_optim/_utils.py | 23 +-
.../sharded_optim/bookkeeping/base_store.py | 10 +-
.../sharded_optim/bookkeeping/bucket_store.py | 19 +-
.../bookkeeping/parameter_store.py | 5 +-
.../zero/sharded_optim/low_level_optim.py | 285 +--
colossalai/zero/utils/gemini_hook.py | 5 +-
docker/Dockerfile | 5 +-
examples/README.md | 48 +-
examples/images/diffusion/README.md | 14 +-
.../diffusion/test_ci.sh} | 0
.../dreambooth/test_ci.sh} | 0
.../dreambooth/train_dreambooth_colossalai.py | 28 +-
examples/images/vit/configs/vit_1d_tp2_ci.py | 32 +
examples/images/vit/requirements.txt | 6 +
examples/images/vit/test_ci.sh | 9 +
examples/images/vit/train.py | 25 +-
examples/images/vit/vit.py | 23 +-
examples/language/gpt/README.md | 17 +-
.../auto_parallel/auto_parallel_with_gpt.py | 20 +-
.../saved_solution/solution_12_layers.pt | Bin 0 -> 1903 bytes
.../saved_solution/solution_1_layers.pt | Bin 0 -> 559 bytes
.../saved_solution/solution_4_layers.pt | Bin 0 -> 943 bytes
.../pipeline_parallel/requirements.txt | 2 +
.../pipeline_parallel/train_gpt_pp.py | 2 +-
.../language/gpt/gemini/benchmark_gemini.sh | 30 +-
.../language/gpt/gemini/commons/model_zoo.py | 12 +
examples/language/gpt/gemini/requirements.txt | 2 +
examples/language/gpt/gemini/run_gemini.sh | 11 +-
examples/language/gpt/gemini/test_ci.sh | 35 +
.../language/gpt/gemini/train_gpt_demo.py | 37 +-
examples/language/gpt/requirements.txt | 1 +
examples/language/gpt/test_ci.sh | 18 +-
examples/language/gpt/titans/LICENSE | 201 +++
examples/language/gpt/titans/README.md | 48 +
.../titans/configs/gpt2_small_zero3_pp1d.py | 31 +
.../gpt/titans/configs/gpt3_zero3_pp1d.py | 31 +
.../language/gpt/titans/dataset/webtext.py | 43 +
.../language/gpt/titans/model/__init__.py | 3 +
examples/language/gpt/titans/model/embed.py | 599 +++++++
examples/language/gpt/titans/model/gpt1d.py | 349 ++++
.../gpt/titans/model/pipeline_gpt1d.py | 322 ++++
examples/language/gpt/titans/requirements.txt | 4 +
examples/language/gpt/titans/run.sh | 3 +
examples/language/gpt/titans/test_ci.sh | 1 +
examples/language/gpt/titans/train_gpt.py | 113 ++
examples/language/opt/requirements.txt | 2 +
examples/language/opt/test_ci.sh | 4 +
examples/language/palm/run.sh | 2 +-
examples/language/palm/test_ci.sh | 9 +
examples/language/palm/train.py | 100 +-
examples/tutorial/README.md | 26 -
examples/tutorial/auto_parallel/README.md | 113 +-
.../auto_parallel_with_resnet.py | 144 +-
examples/tutorial/auto_parallel/config.py | 4 +-
.../tutorial/auto_parallel/requirements.txt | 9 +-
.../setup.py | 6 +-
examples/tutorial/auto_parallel/test_ci.sh | 6 +
examples/tutorial/hybrid_parallel/README.md | 55 +-
examples/tutorial/hybrid_parallel/config.py | 10 +-
.../tutorial/hybrid_parallel/requirements.txt | 5 +-
examples/tutorial/hybrid_parallel/test_ci.sh | 5 +
examples/tutorial/hybrid_parallel/train.py | 24 +-
.../tutorial/large_batch_optimizer/README.md | 42 +-
.../tutorial/large_batch_optimizer/config.py | 26 +-
.../large_batch_optimizer/requirements.txt | 5 +-
.../tutorial/large_batch_optimizer/test_ci.sh | 8 +
.../tutorial/large_batch_optimizer/train.py | 74 +-
examples/tutorial/sequence_parallel/README.md | 147 +-
examples/tutorial/sequence_parallel/config.py | 15 +-
.../sequence_parallel/requirements.txt | 4 +-
.../tutorial/sequence_parallel/test_ci.sh | 7 +
examples/tutorial/sequence_parallel/train.py | 44 +-
examples/tutorial/stable_diffusion/LICENSE | 82 -
examples/tutorial/stable_diffusion/README.md | 149 --
.../configs/train_colossalai.yaml | 116 --
.../configs/train_colossalai_cifar10.yaml | 123 --
.../stable_diffusion/configs/train_ddp.yaml | 113 --
.../configs/train_pokemon.yaml | 121 --
.../stable_diffusion/environment.yaml | 34 -
.../stable_diffusion/ldm/data/base.py | 75 -
.../stable_diffusion/ldm/data/cifar10.py | 184 --
.../stable_diffusion/ldm/data/imagenet.py | 394 -----
.../stable_diffusion/ldm/data/lsun.py | 92 -
.../stable_diffusion/ldm/lr_scheduler.py | 98 --
.../ldm/models/autoencoder.py | 544 ------
.../ldm/models/diffusion/classifier.py | 267 ---
.../ldm/models/diffusion/ddim.py | 240 ---
.../ldm/models/diffusion/ddpm.py | 1554 -----------------
.../ldm/models/diffusion/plms.py | 236 ---
.../stable_diffusion/ldm/modules/attention.py | 314 ----
.../ldm/modules/diffusionmodules/__init__.py | 0
.../ldm/modules/diffusionmodules/model.py | 862 ---------
.../modules/diffusionmodules/openaimodel.py | 1152 ------------
.../ldm/modules/diffusionmodules/util.py | 276 ---
.../ldm/modules/distributions/__init__.py | 0
.../modules/distributions/distributions.py | 92 -
.../stable_diffusion/ldm/modules/ema.py | 76 -
.../ldm/modules/encoders/__init__.py | 0
.../ldm/modules/encoders/modules.py | 264 ---
.../ldm/modules/flash_attention.py | 50 -
.../ldm/modules/image_degradation/__init__.py | 2 -
.../ldm/modules/image_degradation/bsrgan.py | 730 --------
.../modules/image_degradation/bsrgan_light.py | 650 -------
.../modules/image_degradation/utils/test.png | Bin 441072 -> 0 bytes
.../modules/image_degradation/utils_image.py | 916 ----------
.../ldm/modules/losses/__init__.py | 1 -
.../ldm/modules/losses/contperceptual.py | 111 --
.../ldm/modules/losses/vqperceptual.py | 167 --
.../ldm/modules/x_transformer.py | 641 -------
.../tutorial/stable_diffusion/ldm/util.py | 203 ---
examples/tutorial/stable_diffusion/main.py | 830 ---------
.../stable_diffusion/requirements.txt | 22 -
.../scripts/download_first_stages.sh | 41 -
.../scripts/download_models.sh | 49 -
.../stable_diffusion/scripts/img2img.py | 293 ----
.../stable_diffusion/scripts/inpaint.py | 98 --
.../stable_diffusion/scripts/knn2img.py | 398 -----
.../scripts/sample_diffusion.py | 313 ----
.../scripts/tests/test_checkpoint.py | 37 -
.../scripts/tests/test_watermark.py | 18 -
.../scripts/train_searcher.py | 147 --
.../stable_diffusion/scripts/txt2img.py | 344 ----
examples/tutorial/stable_diffusion/train.sh | 4 -
requirements/requirements-test.txt | 1 +
setup.py | 33 +-
.../test_tensor_shard/test_checkpoint.py | 70 +
.../test_binary_elementwise_handler.py | 65 +-
.../test_node_handler/test_shard_option.py | 112 ++
.../test_node_handler/utils.py | 5 +-
.../benchmark_simple_evoformer.py | 94 +
.../test_autochunk/test_evoformer_codegen.py | 163 ++
.../test_evoformer_stack_codegen.py | 163 ++
tests/test_autochunk/test_extramsa_codegen.py | 164 ++
.../test_simple_evoformer_codegen.py | 104 ++
.../test_simple_evoformer_search.py | 97 +
tests/test_gemini/update/test_grad_clip.py | 2 -
tests/test_gemini/update/test_inference.py | 122 ++
tests/test_gemini/update/test_optim.py | 2 -
.../update/test_zeroddp_state_dict.py | 16 +-
tests/test_tensor/common_utils/_utils.py | 17 +-
tests/test_tensor/test_tp_with_zero.py | 4 +-
.../test_zero/low_level_zero/test_grad_acc.py | 5 +-
.../test_zero/low_level_zero/test_zero1_2.py | 2 +-
.../low_level_zero/test_zero_init.py | 61 +
.../test_zero/low_level_zero/test_zero_tp.py | 98 ++
215 files changed, 8523 insertions(+), 14916 deletions(-)
create mode 100644 .bdist.json
create mode 100644 .compatibility
create mode 100644 .github/workflows/README.md
create mode 100644 .github/workflows/auto_compatibility_test.yml
create mode 100644 .github/workflows/auto_example_check.yml
create mode 100644 .github/workflows/auto_release_bdist.yml
delete mode 100644 .github/workflows/changed_file_trigger_examples_check_and_weekly_check.yml
rename .github/workflows/{compatibility_test.yml => dispatch_compatibility_test.yml} (98%)
rename .github/workflows/{workflow_dispatch_example.yml => dispatch_example_check.yml} (57%)
create mode 100644 .github/workflows/pre_commit.yml
create mode 100644 .github/workflows/report_precommit_failure.yml
create mode 100644 .github/workflows/report_test_coverage.yml
create mode 100644 .github/workflows/scripts/example_checks/check_dispatch_inputs.py
rename .github/workflows/scripts/{weekly_check_example.py => example_checks/check_example_weekly.py} (76%)
rename .github/workflows/scripts/{changed_example.py => example_checks/detect_changed_example.py} (52%)
delete mode 100644 .github/workflows/scripts/input_check_example.py
create mode 100644 .github/workflows/translate_comment.yml
create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/option.py
create mode 100644 colossalai/autochunk/autochunk_codegen.py
create mode 100644 colossalai/autochunk/estimate_memory.py
create mode 100644 colossalai/autochunk/reorder_graph.py
create mode 100644 colossalai/autochunk/search_chunk.py
create mode 100644 colossalai/autochunk/select_chunk.py
create mode 100644 colossalai/autochunk/trace_flow.py
create mode 100644 colossalai/autochunk/trace_indice.py
create mode 100644 colossalai/autochunk/utils.py
create mode 100644 colossalai/utils/model/experimental.py
rename examples/{tutorial/stable_diffusion/ldm/data/__init__.py => images/diffusion/test_ci.sh} (100%)
rename examples/{tutorial/stable_diffusion/ldm/models/diffusion/__init__.py => images/dreambooth/test_ci.sh} (100%)
create mode 100644 examples/images/vit/configs/vit_1d_tp2_ci.py
create mode 100644 examples/images/vit/test_ci.sh
create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt
create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt
create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt
create mode 100644 examples/language/gpt/experiments/pipeline_parallel/requirements.txt
create mode 100644 examples/language/gpt/gemini/requirements.txt
create mode 100644 examples/language/gpt/gemini/test_ci.sh
create mode 100644 examples/language/gpt/titans/LICENSE
create mode 100644 examples/language/gpt/titans/README.md
create mode 100644 examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
create mode 100644 examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
create mode 100644 examples/language/gpt/titans/dataset/webtext.py
create mode 100644 examples/language/gpt/titans/model/__init__.py
create mode 100644 examples/language/gpt/titans/model/embed.py
create mode 100644 examples/language/gpt/titans/model/gpt1d.py
create mode 100644 examples/language/gpt/titans/model/pipeline_gpt1d.py
create mode 100644 examples/language/gpt/titans/requirements.txt
create mode 100644 examples/language/gpt/titans/run.sh
create mode 100644 examples/language/gpt/titans/test_ci.sh
create mode 100644 examples/language/gpt/titans/train_gpt.py
create mode 100644 examples/language/opt/requirements.txt
create mode 100644 examples/language/opt/test_ci.sh
create mode 100644 examples/language/palm/test_ci.sh
rename examples/tutorial/{stable_diffusion => auto_parallel}/setup.py (68%)
create mode 100644 examples/tutorial/auto_parallel/test_ci.sh
create mode 100644 examples/tutorial/hybrid_parallel/test_ci.sh
create mode 100644 examples/tutorial/large_batch_optimizer/test_ci.sh
create mode 100644 examples/tutorial/sequence_parallel/test_ci.sh
delete mode 100644 examples/tutorial/stable_diffusion/LICENSE
delete mode 100644 examples/tutorial/stable_diffusion/README.md
delete mode 100644 examples/tutorial/stable_diffusion/configs/train_colossalai.yaml
delete mode 100644 examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml
delete mode 100644 examples/tutorial/stable_diffusion/configs/train_ddp.yaml
delete mode 100644 examples/tutorial/stable_diffusion/configs/train_pokemon.yaml
delete mode 100644 examples/tutorial/stable_diffusion/environment.yaml
delete mode 100644 examples/tutorial/stable_diffusion/ldm/data/base.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/data/cifar10.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/data/imagenet.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/data/lsun.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/lr_scheduler.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/models/autoencoder.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/attention.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/ema.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py
delete mode 100644 examples/tutorial/stable_diffusion/ldm/util.py
delete mode 100644 examples/tutorial/stable_diffusion/main.py
delete mode 100644 examples/tutorial/stable_diffusion/requirements.txt
delete mode 100644 examples/tutorial/stable_diffusion/scripts/download_first_stages.sh
delete mode 100644 examples/tutorial/stable_diffusion/scripts/download_models.sh
delete mode 100644 examples/tutorial/stable_diffusion/scripts/img2img.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/inpaint.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/knn2img.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/sample_diffusion.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/train_searcher.py
delete mode 100644 examples/tutorial/stable_diffusion/scripts/txt2img.py
delete mode 100644 examples/tutorial/stable_diffusion/train.sh
create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
create mode 100644 tests/test_autochunk/benchmark_simple_evoformer.py
create mode 100644 tests/test_autochunk/test_evoformer_codegen.py
create mode 100644 tests/test_autochunk/test_evoformer_stack_codegen.py
create mode 100644 tests/test_autochunk/test_extramsa_codegen.py
create mode 100644 tests/test_autochunk/test_simple_evoformer_codegen.py
create mode 100644 tests/test_autochunk/test_simple_evoformer_search.py
create mode 100644 tests/test_gemini/update/test_inference.py
create mode 100644 tests/test_zero/low_level_zero/test_zero_init.py
create mode 100644 tests/test_zero/low_level_zero/test_zero_tp.py
diff --git a/.bdist.json b/.bdist.json
new file mode 100644
index 000000000..8693bca48
--- /dev/null
+++ b/.bdist.json
@@ -0,0 +1,24 @@
+{
+ "build": [
+ {
+ "torch_version": "1.11.0",
+ "cuda_image": "hpcaitech/cuda-conda:10.2"
+ },
+ {
+ "torch_version": "1.11.0",
+ "cuda_image": "hpcaitech/cuda-conda:11.3"
+ },
+ {
+ "torch_version": "1.12.1",
+ "cuda_image": "hpcaitech/cuda-conda:10.2"
+ },
+ {
+ "torch_version": "1.12.1",
+ "cuda_image": "hpcaitech/cuda-conda:11.3"
+ },
+ {
+ "torch_version": "1.12.1",
+ "cuda_image": "hpcaitech/cuda-conda:11.6"
+ }
+ ]
+}
diff --git a/.compatibility b/.compatibility
new file mode 100644
index 000000000..c8ac4083d
--- /dev/null
+++ b/.compatibility
@@ -0,0 +1,3 @@
+1.12.0-11.3.0
+1.11.0-11.3.0
+1.10.1-11.3.0
diff --git a/.github/workflows/README.md b/.github/workflows/README.md
new file mode 100644
index 000000000..cda6a3139
--- /dev/null
+++ b/.github/workflows/README.md
@@ -0,0 +1,149 @@
+# CI/CD
+
+## Table of Contents
+
+- [CI/CD](#cicd)
+ - [Table of Contents](#table-of-contents)
+ - [Overview](#overview)
+ - [Workflows](#workflows)
+ - [Checks on Pull Requests](#checks-on-pull-requests)
+ - [Regular Checks](#regular-checks)
+ - [Release](#release)
+ - [Manual Dispatch](#manual-dispatch)
+ - [Release bdist wheel](#release-bdist-wheel)
+ - [Dispatch Example Test](#dispatch-example-test)
+ - [Compatibility Test](#compatibility-test)
+ - [User Friendliness](#user-friendliness)
+ - [Configuration](#configuration)
+ - [Progress Log](#progress-log)
+
+## Overview
+
+Automation makes our development more efficient as the machine automatically runs the pre-defined tasks for the contributors.
+This saves a lot of manual work and allows the developers to fully focus on features and bug fixes.
+In Colossal-AI, we use [GitHub Actions](https://github.com/features/actions) to automate a wide range of workflows to ensure the robustness of the software.
+In the section below, we dive into the details of the different workflows available.
+
+## Workflows
+
+### Checks on Pull Requests
+
+| Workflow Name | File name | Description |
+| --------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Build` | `build.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. |
+| `Pre-commit` | `pre_commit.yml` | This workflow runs pre-commit checks for code style consistency. |
+| `Report pre-commit failure` | `report_precommit_failure.yml` | This workflow posts a comment on the PR to explain the pre-commit failure and its remedy. It is executed when `Pre-commit` completes. |
+| `Report test coverage` | `report_test_coverage.yml` | This workflow posts a comment on the PR to report the test coverage results. It is executed when `Build` completes. |
+| `Test example` | `auto_example_check.yml` | An example is automatically tested if its files are changed in the PR. |
+
+### Regular Checks
+
+| Workflow Name | File name | Description |
+| ----------------------- | ----------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Test example` | `auto_example_check.yml` | This workflow will test all examples every Sunday |
+| `Compatibility Test` | `auto_compatibility_test.yml` | This workflow will check the compatibility of Colossal-AI against PyTorch and CUDA every Sunday. The PyTorch and CUDA versions are specified in `.compatibility`. |
+| `Build on 8 GPUs` | `build_gpu_8.yml` | This workflow will run the unit tests every day with 8 GPUs. |
+| `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. |
+| `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which are stale for 14 days. |
+
+### Release
+
+| Workflow Name | File name | Description |
+| --------------------------- | ------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Draft GitHub Release Post` | `draft_github_release_post.yml` | Compose a GitHub release post draft based on the commit history. Triggered when the change of `version.txt` is merged. |
+| `Release to PyPI` | `release_pypi.yml` | Build and release the wheel to PyPI. Triggered when the change of `version.txt` is merged. |
+| `Release Nightly to PyPI` | `release_nightly.yml` | Build and release the nightly wheel to PyPI as `colossalai-nightly`. Automatically executed every Sunday. |
+| `Release Docker` | `release_docker.yml` | Build and release the Docker image to DockerHub. Triggered when the change of `version.txt` is merged. |
+| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Manually dispatched. See more details in the next section. |
+| `Auto Release bdist wheel` | `auto_release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Triggered when the change of `version.txt` is merged. Build specifications are stored in `.bdist.json`. |
+| `Auto Compatibility Test` | `auto_compatibility_test.yml` | Check Colossal-AI's compatibility against the PyTorch and CUDA versions specified in `.compatibility`. Triggered when `version.txt` is changed in a PR. |
+
+### Manual Dispatch
+
+| Workflow Name | File name | Description |
+| ---------------------------- | -------------------------------- | ------------------------------------------------------ |
+| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. |
+| `Dispatch Example Test` | `dispatch_example_check.yml` | Manually test a specified example. |
+| `Dispatch Compatibility Test` | `dispatch_compatibility_test.yml` | Test PyTorch and CUDA compatibility. |
+
+Refer to this [documentation](https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow) on how to manually trigger a workflow.
+The details of each workflow are given below.
+
+#### Release bdist wheel
+
+Parameters:
+- `torch version`: torch versions to test against; multiple versions are supported but must be separated by commas. The default value is `all`, which tests all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels), which is regularly updated.
+- `cuda version`: cuda versions to test against; multiple versions are supported but must be separated by commas. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).
+- `ref`: the branch or tag name to build the wheel from.
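+
+As a hedged sketch, such a build could also be dispatched from the GitHub CLI; the input ids `torch_version`, `cuda_version` and `ref` below are assumptions, so check `release_bdist.yml` for the exact names:
+
+```bash
+# manually dispatch the bdist wheel build (input ids are assumed, not verified)
+gh workflow run release_bdist.yml \
+  -f torch_version="1.12.1" \
+  -f cuda_version="11.3" \
+  -f ref="main"
+```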
+
+#### Dispatch Example Test
+
+Parameters:
+- `example_directory`: the example directory to test. Multiple directories are supported and must be separated by commas, e.g. `language/gpt,images/vit`. Inputting only `language` or only `gpt` does not work.
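+
+For instance, this check could be dispatched from the GitHub CLI roughly as follows; `example_directory` is the input defined in `dispatch_example_check.yml`, but the exact expected value format should be checked against that workflow:
+
+```bash
+# manually test one example directory (a sketch, not a verified invocation)
+gh workflow run dispatch_example_check.yml -f example_directory="language/gpt"
+```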
+
+
+#### Compatibility Test
+
+Parameters:
+- `torch version`: torch versions to test against; multiple versions are supported but must be separated by commas. The default value is `all`, which tests all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels).
+- `cuda version`: cuda versions to test against; multiple versions are supported but must be separated by commas. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).
+
+> It only tests the compatibility of the main branch.
+
+
+### User Friendliness
+
+| Workflow Name | File name | Description |
+| ----------------- | ----------------------- | -------------------------------------------------------------------------------------------------------------------------------------- |
+| `issue-translate` | `translate_comment.yml` | This workflow is triggered when a new issue comment is created. The comment will be translated into English if not written in English. |
+
+
+## Configuration
+
+This section lists the files used to configure the workflow.
+
+1. `.compatibility`
+
+The `.compatibility` file tells GitHub Actions which PyTorch and CUDA versions to test against. Each line in the file is in the format `${torch-version}-${cuda-version}`, which is a tag for a Docker image. Thus, this tag must be present in the [docker registry](https://hub.docker.com/r/pytorch/conda-cuda) so that the test can be performed.
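+
+As an illustration, the snippet below mirrors how `auto_compatibility_test.yml` in this PR turns each line into a container image name; the `hpcaitech/pytorch-cuda` repository is taken from that workflow and is otherwise an assumption:
+
+```bash
+# read each `${torch-version}-${cuda-version}` line and print the image it maps to
+while read tag; do
+  echo "hpcaitech/pytorch-cuda:${tag}"
+done < .compatibility
+```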
+
+2. `.bdist.json`
+
+This file controls which PyTorch/CUDA-compatible pre-built wheels will be built and published. You can add a new entry following the JSON schema below if a new wheel needs to be built with AOT compilation of PyTorch extensions.
+
+```json
+{
+ "build": [
+ {
+ "torch_version": "",
+ "cuda_image": ""
+ },
+ ]
+}
+```
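+
+As a quick, optional sanity check before committing a new entry (a suggested step, not part of any workflow in this PR), the file can be validated locally:
+
+```bash
+# fail loudly if .bdist.json is not valid JSON
+python -m json.tool .bdist.json > /dev/null && echo ".bdist.json is valid"
+
+# roughly mirrors how auto_release_bdist.yml flattens the file into a build matrix
+cat .bdist.json | tr '\n' ' '
+```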
+
+## Progress Log
+
+- [x] unit testing
+ - [x] test on PR
+ - [x] report test coverage
+ - [x] regular test
+- [x] release
+ - [x] official release
+ - [x] nightly build
+ - [x] binary build
+ - [x] docker build
+ - [x] draft release post
+- [x] pre-commit
+ - [x] check on PR
+ - [x] report failure
+- [x] example check
+ - [x] check on PR
+ - [x] regular check
+ - [x] manual dispatch
+- [x] compatibility check
+ - [x] manual dispatch
+ - [x] auto test when release
+- [x] helpers
+ - [x] comment translation
+ - [x] submodule update
+ - [x] close inactive issue
diff --git a/.github/workflows/auto_compatibility_test.yml b/.github/workflows/auto_compatibility_test.yml
new file mode 100644
index 000000000..4b026c63e
--- /dev/null
+++ b/.github/workflows/auto_compatibility_test.yml
@@ -0,0 +1,74 @@
+name: Compatibility Test
+
+on:
+ pull_request:
+ paths:
+ - 'version.txt'
+ - '.compatibility'
+ # run at 03:00 every Sunday (Singapore time), which is 19:00 UTC on Saturday
+ schedule:
+ - cron: '0 19 * * 6'
+
+jobs:
+ matrix_preparation:
+ name: Prepare Container List
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v3
+ - id: set-matrix
+ run: |
+ IFS=','
+ DOCKER_IMAGE=()
+
+ while read tag; do
+ DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tag}\"")
+ done <.compatibility
+
+ container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
+ container="[${container}]"
+ echo "$container"
+ echo "::set-output name=matrix::{\"container\":$(echo "$container")}"
+
+ build:
+ name: Test for PyTorch Compatibility
+ needs: matrix_preparation
+ if: github.repository == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ strategy:
+ fail-fast: false
+ matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
+ container:
+ image: ${{ matrix.container }}
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ timeout-minutes: 120
+ steps:
+ - name: Install dependencies
+ run: |
+ pip install -U pip setuptools wheel --user
+ - uses: actions/checkout@v2
+ with:
+ repository: hpcaitech/TensorNVMe
+ ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ path: TensorNVMe
+ - name: Install tensornvme
+ run: |
+ cd TensorNVMe
+ conda install cmake
+ pip install -r requirements.txt
+ pip install -v .
+ - uses: actions/checkout@v2
+ with:
+ ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ - name: Install Colossal-AI
+ run: |
+ pip install -v --no-cache-dir .
+ pip install -r requirements/requirements-test.txt
+ - name: Unit Testing
+ run: |
+ PYTHONPATH=$PWD pytest tests
+ env:
+ DATA: /data/scratch/cifar-10
+ NCCL_SHM_DISABLE: 1
+ LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
diff --git a/.github/workflows/auto_example_check.yml b/.github/workflows/auto_example_check.yml
new file mode 100644
index 000000000..df413f646
--- /dev/null
+++ b/.github/workflows/auto_example_check.yml
@@ -0,0 +1,143 @@
+name: Test Example
+on:
+ pull_request:
+ # any change in the examples folder will trigger check for the corresponding example.
+ paths:
+ - 'examples/**'
+ # run at 00:00 every Sunday (Singapore time), which is 16:00 UTC on Saturday
+ schedule:
+ - cron: '0 16 * * 6'
+
+jobs:
+ # Detect changed example files and output a matrix containing all the corresponding directory names.
+ detect-changed-example:
+ if: |
+ github.event.pull_request.draft == false &&
+ github.base_ref == 'main' &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.setup-matrix.outputs.matrix }}
+ anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
+ name: Detect changed example files
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ ref: ${{ github.event.pull_request.head.sha }}
+
+ - name: Locate base commit
+ id: locate-base-sha
+ run: |
+ curBranch=$(git rev-parse --abbrev-ref HEAD)
+ commonCommit=$(git merge-base origin/main $curBranch)
+ echo $commonCommit
+ echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
+
+ - name: Get all changed example files
+ id: changed-files
+ uses: tj-actions/changed-files@v35
+ with:
+ base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
+
+ - name: setup matrix
+ id: setup-matrix
+ run: |
+ changedFileName=""
+ for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
+ changedFileName="${file}:${changedFileName}"
+ done
+ echo "$changedFileName was changed"
+ res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName`
+ echo "All changed examples are $res"
+
+ if [ "$res" = "[]" ]; then
+ echo "anyChanged=false" >> $GITHUB_OUTPUT
+ echo "matrix=null" >> $GITHUB_OUTPUT
+ else
+ dirs=$( IFS=',' ; echo "${res[*]}" )
+ echo "anyChanged=true" >> $GITHUB_OUTPUT
+ echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
+ fi
+
+ # If no file is changed, this will raise an error because the matrix has no value.
+ check-changed-example:
+ # Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
+ if: |
+ github.event.pull_request.draft == false &&
+ github.base_ref == 'main' &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
+ needs.detect-changed-example.outputs.anyChanged == 'true'
+ name: Test the changed example
+ needs: detect-changed-example
+ runs-on: [self-hosted, gpu]
+ strategy:
+ matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --gpus all --rm -v /data/scratch/examples-data:/data/
+ timeout-minutes: 10
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Install Colossal-AI
+ run: |
+ pip install -v .
+
+ - name: Test the example
+ run: |
+ example_dir=${{ matrix.directory }}
+ cd "${PWD}/examples/${example_dir}"
+ bash test_ci.sh
+ env:
+ NCCL_SHM_DISABLE: 1
+
+ # This is for the weekly check of all examples. Specifically, this job finds all the example directories.
+ matrix_preparation:
+ if: |
+ github.repository == 'hpcaitech/ColossalAI' &&
+ github.event_name == 'schedule'
+ name: Prepare matrix for weekly check
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.setup-matrix.outputs.matrix }}
+ steps:
+ - name: 📚 Checkout
+ uses: actions/checkout@v3
+
+ - name: setup matrix
+ id: setup-matrix
+ run: |
+ res=`python .github/workflows/scripts/example_checks/check_example_weekly.py`
+ all_loc=$( IFS=',' ; echo "${res[*]}" )
+ echo "Found the examples: $all_loc"
+ echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT
+
+ weekly_check:
+ if: |
+ github.repository == 'hpcaitech/ColossalAI' &&
+ github.event_name == 'schedule'
+ name: Weekly check all examples
+ needs: matrix_preparation
+ runs-on: [self-hosted, gpu]
+ strategy:
+ matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ timeout-minutes: 10
+ steps:
+ - name: 📚 Checkout
+ uses: actions/checkout@v3
+
+ - name: Install Colossal-AI
+ run: |
+ pip install -v .
+
+ - name: Traverse all files
+ run: |
+ example_dir=${{ matrix.directory }}
+ echo "Testing ${example_dir} now"
+ cd "${PWD}/examples/${example_dir}"
+ bash test_ci.sh
+ env:
+ NCCL_SHM_DISABLE: 1
diff --git a/.github/workflows/auto_release_bdist.yml b/.github/workflows/auto_release_bdist.yml
new file mode 100644
index 000000000..56a3036f8
--- /dev/null
+++ b/.github/workflows/auto_release_bdist.yml
@@ -0,0 +1,70 @@
+name: Auto Release bdist wheel
+
+on:
+ workflow_dispatch:
+ pull_request:
+ paths:
+ - 'version.txt'
+ types:
+ - closed
+
+jobs:
+ matrix_preparation:
+ name: Prepare Container List
+ if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v3
+ - id: set-matrix
+ run: |
+ bdist=$(cat .bdist.json | tr '\n' ' ')
+ echo "matrix=${bdist}" >> $GITHUB_OUTPUT
+
+ build:
+ name: Release bdist wheels
+ needs: matrix_preparation
+ runs-on: [self-hosted, gpu]
+ strategy:
+ fail-fast: false
+ matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
+ container:
+ image: ${{ matrix.build.cuda_image }}
+ options: --gpus all --rm
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+ # cub is for cuda 10.2
+ - name: Copy scripts
+ run: |
+ cp -r ./.github/workflows/scripts/* ./
+
+ # link the cache directories to the current path
+ ln -s /github/home/conda_pkgs ./conda_pkgs
+ ln -s /github/home/pip_wheels ./pip_wheels
+
+ # set the conda package path
+ echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc
+
+ # set safe directory
+ git config --global --add safe.directory /__w/ColossalAI/ColossalAI
+
+ # get cub package for cuda 10.2
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ - name: Build bdist wheel
+ run: |
+ pip install beautifulsoup4 requests packaging
+ python ./build_colossalai_wheel.py --torch_version $TORCH_VERSIONS
+ env:
+ TORCH_VERSIONS: ${{ matrix.build.torch_version }}
+ - name: 🚀 Deploy
+ uses: garygrossgarten/github-action-scp@release
+ with:
+ local: all_dist
+ remote: ${{ secrets.PRIVATE_PYPI_DIR }}
+ host: ${{ secrets.PRIVATE_PYPI_HOST }}
+ username: ${{ secrets.PRIVATE_PYPI_USER }}
+ password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 5366f69cc..8f334d599 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -20,15 +20,26 @@ jobs:
- uses: actions/checkout@v2
with:
fetch-depth: 0
+ ref: ${{ github.event.pull_request.head.sha }}
+
+ - name: Locate base commit
+ id: locate-base-sha
+ run: |
+ curBranch=$(git rev-parse --abbrev-ref HEAD)
+ commonCommit=$(git merge-base origin/main $curBranch)
+ echo $commonCommit
+ echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
+
- name: Find the changed files
id: find-changed-files
uses: tj-actions/changed-files@v35
with:
- since_last_remote_commit: true
+ base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
files: |
op_builder/**
colossalai/kernel/**
setup.py
+
- name: List changed files
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
@@ -75,12 +86,26 @@ jobs:
- name: Unit Testing
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ - name: Collate artifact
+ env:
+ PR_NUMBER: ${{ github.event.number }}
+ run: |
+ mkdir report
+ echo $PR_NUMBER > ./report/pr_number
+ mv coverage.xml ./report
+
+ - name: Upload test coverage artifact
+ uses: actions/upload-artifact@v3
+ with:
+ name: report
+ path: report/
+
- name: Store Cache
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
diff --git a/.github/workflows/changed_file_trigger_examples_check_and_weekly_check.yml b/.github/workflows/changed_file_trigger_examples_check_and_weekly_check.yml
deleted file mode 100644
index 2b7ec3125..000000000
--- a/.github/workflows/changed_file_trigger_examples_check_and_weekly_check.yml
+++ /dev/null
@@ -1,119 +0,0 @@
-name: Test Example
-on:
- pull_request:
- # So only the changes in examples folder will trigger jobs below.
- paths:
- - 'examples/**'
- # run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00
- schedule:
- - cron: '0 16 * * 6'
-
-jobs:
- # This is for changed example files detect and output a matrix containing all the corresponding directory name.
- detect-changed-example:
- if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
- runs-on: ubuntu-latest
- outputs:
- matrix: ${{ steps.set-matrix.outputs.matrix }}
- name: Check out all files
- steps:
- - uses: actions/checkout@v3
- with:
- fetch-depth: 2
- - name: Get all changed example files
- id: changed-files
- uses: tj-actions/changed-files@v35
- # Using this can trigger action each time a PR is submitted.
- with:
- since_last_remote_commit: true
- - name: setup matrix
- id: set-matrix
- run: |
- changedFileName=""
- for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
- changedFileName="${file}:${changedFileName}"
- done
- echo "$changedFileName was changed"
- res=`python .github/workflows/scripts/changed_example.py --fileNameList $changedFileName`
- echo "All changed files are $res"
- loc=$( IFS=',' ; echo "${res[*]}" )
- echo "$loc"
- echo "::set-output name=matrix::{\"loc\":$(echo "$loc")}"
-
- # If no file is changed, it will prompt an error and shows the matrix do not have value.
- check-all-changed-files:
- # Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
- if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
- name: Test each changed example files
- needs: detect-changed-example
- runs-on: [self-hosted, gpu]
- strategy:
- matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
- container:
- image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
- steps:
- - uses: actions/checkout@v3
- with:
- fetch-depth: 2
- - name: Install dependancies
- run: |
- pip install -r ./requirements/requirements.txt
- pip install colossalai
- - name: List all changed example files
- run: |
- res=${{ matrix.loc }}
- cd "${PWD}/examples/${res}"
- bash test_ci.sh
-
- # This is for all files' weekly check. Specifically, this job is to find all the directories.
- matrix_preparation:
- if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
- name: Prepare Directory List for All files
- runs-on: ubuntu-latest
- outputs:
- matrix: ${{ steps.set-matrix.outputs.matrix }}
- steps:
- - name: 📚 Checkout
- uses: actions/checkout@v3
- - name: setup matrix
- id: set-matrix
- run: |
- res=`python .github/workflows/scripts/weekly_check_example.py`
- all_loc=$( IFS=',' ; echo "${res[*]}" )
- echo "$all_loc"
- echo "::set-output name=matrix::{\"all_loc\":$(echo "$all_loc")}"
-
- weekly_check:
- if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
- name: Weekly check all examples
- needs: matrix_preparation
- runs-on: [self-hosted, gpu]
- strategy:
- matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
- container:
- image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
- steps:
- - name: 📚 Checkout
- uses: actions/checkout@v3
- - name: Install the requirements
- run: |
- pip install -r ./requirements/requirements.txt
- pip install colossalai
- - name: Traverse all files
- run: |
- dir=${{ matrix.all_loc }}
- echo "${dir} is current directory"
- cd "${PWD}/examples/${dir}"
- bash test_ci.sh
diff --git a/.github/workflows/compatibility_test.yml b/.github/workflows/dispatch_compatibility_test.yml
similarity index 98%
rename from .github/workflows/compatibility_test.yml
rename to .github/workflows/dispatch_compatibility_test.yml
index eadd07886..ac5669c6f 100644
--- a/.github/workflows/compatibility_test.yml
+++ b/.github/workflows/dispatch_compatibility_test.yml
@@ -1,4 +1,4 @@
-name: Compatibility Test
+name: Dispatch Compatibility Test
on:
workflow_dispatch:
diff --git a/.github/workflows/workflow_dispatch_example.yml b/.github/workflows/dispatch_example_check.yml
similarity index 57%
rename from .github/workflows/workflow_dispatch_example.yml
rename to .github/workflows/dispatch_example_check.yml
index d9d576910..e0333422f 100644
--- a/.github/workflows/workflow_dispatch_example.yml
+++ b/.github/workflows/dispatch_example_check.yml
@@ -8,7 +8,7 @@ on:
required: true
jobs:
- manual_check_matrix_preparation:
+ matrix_preparation:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
@@ -16,31 +16,24 @@ jobs:
name: Check the examples user want
runs-on: ubuntu-latest
outputs:
- matrix: ${{ steps.set-matrix-1.outputs.matrix }}
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- - name: Get manual directories
- id: set-matrix-1
+ - name: Set up matrix
+ id: set-matrix
env:
check_dir: ${{ inputs.example_directory }}
run: |
- all_mannual_check_dir=()
- for cdi in $check_dir
- do
- all_mannual_check_dir+=("\"${cdi}\"")
- done
- man_loc=$( IFS=',' ; echo "${all_mannual_check_dir[*]}" )
- res=`python .github/workflows/scripts/input_check_example.py --fileNameList $man_loc`
- echo "${res} is file existance. 1 for all exist, -1 for at least one file not exist."
- if [ res == -1 ];then
- exit(1)
+ res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir`
+ if [ "$res" == "failure" ]; then
+ exit -1
fi
- man_loc="[${man_loc}]"
- echo "$man_loc"
- echo "::set-output name=matrix::{\"man_loc\":$(echo "$man_loc")}"
+ dirs="[${check_dir}]"
+ echo "Testing examples in $dirs"
+ echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
- manual_check:
+ test_example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
@@ -52,16 +45,19 @@ jobs:
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --gpus all --rm -v /data/scratch/examples-data:/data/
+ timeout-minutes: 10
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- - name: Install the requirements
+ - name: Install Colossal-AI
run: |
- pip install -r ./requirements/requirements.txt
- pip install colossalai
- - name: Traverse all files
+ pip install -v .
+ - name: Test the example
run: |
- dir=${{ matrix.man_loc }}
- echo "${dir} is current directory"
+ dir=${{ matrix.directory }}
+ echo "Testing ${dir} now"
cd "${PWD}/examples/${dir}"
bash test_ci.sh
+ env:
+ NCCL_SHM_DISABLE: 1
diff --git a/.github/workflows/draft_github_release_post.yml b/.github/workflows/draft_github_release_post.yml
index 413714daf..53bfa9e8d 100644
--- a/.github/workflows/draft_github_release_post.yml
+++ b/.github/workflows/draft_github_release_post.yml
@@ -8,11 +8,10 @@ on:
types:
- closed
-
jobs:
release:
name: Draft Release Post
- if: github.repository == 'hpcaitech/ColossalAI'
+ if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/pre_commit.yml b/.github/workflows/pre_commit.yml
new file mode 100644
index 000000000..3e71be2fc
--- /dev/null
+++ b/.github/workflows/pre_commit.yml
@@ -0,0 +1,71 @@
+name: pre-commit
+
+on:
+ pull_request:
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+ ref: ${{ github.event.pull_request.head.sha }}
+
+ # the PR branch and the hpcaitech/colossal-ai main branch
+ # must share a common commit, we need to locate that commit,
+ # which is the commit checked-out or forked when the PR branch is created
+ # such that we can look for files changed since that commit
+ - name: Locate base commit
+ id: locate-base-sha
+ run: |
+ curBranch=$(git rev-parse --abbrev-ref HEAD)
+ commonCommit=$(git merge-base origin/main $curBranch)
+ echo $commonCommit
+ echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
+
+ - name: Find the changed files
+ id: find-changed-files
+ uses: tj-actions/changed-files@v35
+ with:
+ base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
+
+ - name: List all changed files
+ run: |
+ for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
+ echo "$file was changed"
+ done
+
+ - uses: actions/setup-python@v3
+
+ - name: Cache pre-commit hooks
+ uses: actions/cache@v3
+ with:
+ path: ~/.cache/pre-commit
+ key: ${{ runner.os }}-pre-commit-hooks
+
+ - name: Set up pre-commit
+ run: |
+ pip install pre-commit
+ pre-commit install
+
+ - name: Run pre-commit on Changed Files
+ id: precommit
+ run: |
+ for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
+ echo "======= running pre-commit on ${file} ======="
+ pre-commit run --files $file
+ done
+
+ - name: Save PR number
+ if: always()
+ env:
+ PR_NUMBER: ${{ github.event.number }}
+ run: |
+ mkdir -p ./pr
+ echo $PR_NUMBER > ./pr/pr_number
+ - uses: actions/upload-artifact@v3
+ if: always()
+ with:
+ name: pr_number
+ path: pr/
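
The "Locate base commit" step above uses `git merge-base` to find the commit the PR branch forked from, so only files changed since that point are checked. A minimal Python sketch of the same idea (not part of the patch; it assumes it runs inside a checked-out git repository with an `origin/main` remote):

```python
import subprocess

def changed_files_since_fork(base_branch="origin/main"):
    """List files changed since the current branch diverged from base_branch."""
    head = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True).strip()
    # the merge-base is the commit the PR branch was forked from
    base = subprocess.check_output(["git", "merge-base", base_branch, head], text=True).strip()
    diff = subprocess.check_output(["git", "diff", "--name-only", base, "HEAD"], text=True)
    return diff.splitlines()
```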
diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml
index 328d232a8..8da6e5f87 100644
--- a/.github/workflows/release_docker.yml
+++ b/.github/workflows/release_docker.yml
@@ -2,13 +2,16 @@ name: Publish Docker Image to DockerHub
on:
workflow_dispatch:
- release:
- types: [published]
+ pull_request:
+ paths:
+ - 'version.txt'
+ types:
+ - closed
jobs:
release:
name: Publish Docker Image to DockerHub
- if: github.repository == 'hpcaitech/ColossalAI'
+ if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: "hpcaitech/docker-in-docker:latest"
@@ -18,23 +21,17 @@ jobs:
with:
fetch-depth: 0
- name: Build Docker
+ id: build
run: |
version=$(cat version.txt)
- docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker
+ tag=hpcaitech/colossalai:$version
+ docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
+ echo "tag=${tag}" >> $GITHUB_OUTPUT
- name: Log in to Docker Hub
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - name: Extract metadata (tags, labels) for Docker
- id: meta
- uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
- with:
- images: hpcaitech/colossalai
- - name: Build and push Docker image
- uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
- with:
- context: .
- push: true
- tags: ${{ steps.meta.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels }}
+ - name: Push Docker image
+ run: |
+ docker push ${{ steps.build.outputs.tag }}
diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml
index 6bc000d1f..8aa48b8ed 100644
--- a/.github/workflows/release_nightly.yml
+++ b/.github/workflows/release_nightly.yml
@@ -1,73 +1,29 @@
-name: Release bdist wheel for Nightly versions
+name: Publish Nightly Version to PyPI
on:
- schedule:
- # run at 00:00 of every Sunday
- - cron: '0 0 * * 6'
workflow_dispatch:
+ schedule:
+ - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
jobs:
- matrix_preparation:
- name: Prepare Container List
+ build-n-publish:
+ if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI'
+ name: Build and publish Python 🐍 distributions 📦 to PyPI
runs-on: ubuntu-latest
- outputs:
- matrix: ${{ steps.set-matrix.outputs.matrix }}
+ timeout-minutes: 20
steps:
- - id: set-matrix
- run: |
- matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]"
- echo $matrix
- echo "::set-output name=matrix::{\"container\":$(echo $matrix)}"
+ - uses: actions/checkout@v2
- build:
- name: Release bdist wheels
- needs: matrix_preparation
- if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
- runs-on: [self-hosted, gpu]
- strategy:
- fail-fast: false
- matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
- container:
- image: ${{ matrix.container }}
- options: --gpus all --rm
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
- # cub is for cuda 10.2
- - name: Copy scripts and checkout
- run: |
- cp -r ./.github/workflows/scripts/* ./
- ln -s /github/home/pip_wheels ./pip_wheels
- wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
- unzip 1.8.0.zip
- - name: Build bdist wheel
- run: |
- pip install beautifulsoup4 requests packaging
- python ./build_colossalai_wheel.py --nightly
- - name: 🚀 Deploy
- uses: garygrossgarten/github-action-scp@release
- with:
- local: all_dist
- remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
- host: ${{ secrets.PRIVATE_PYPI_HOST }}
- username: ${{ secrets.PRIVATE_PYPI_USER }}
- password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
- remove_old_build:
- name: Remove old nightly build
- runs-on: ubuntu-latest
- needs: build
- steps:
- - name: executing remote ssh commands using password
- uses: appleboy/ssh-action@master
- env:
- BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
- with:
- host: ${{ secrets.PRIVATE_PYPI_HOST }}
- username: ${{ secrets.PRIVATE_PYPI_USER }}
- password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
- envs: BUILD_DIR
- script: |
- cd $BUILD_DIR
- find . -type f -mtime +0 -exec rm -f {} +
- script_stop: true
+ - uses: actions/setup-python@v2
+ with:
+ python-version: '3.8.14'
+
+ - run: NIGHTLY=1 python setup.py sdist build
+
+ # publish to PyPI if executed on the main branch
+ - name: Publish package to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ user: __token__
+ password: ${{ secrets.PYPI_API_TOKEN }}
+ verbose: true
diff --git a/.github/workflows/report_precommit_failure.yml b/.github/workflows/report_precommit_failure.yml
new file mode 100644
index 000000000..e6ca7b01b
--- /dev/null
+++ b/.github/workflows/report_precommit_failure.yml
@@ -0,0 +1,67 @@
+name: Report Precommit Failure
+
+on:
+ workflow_run:
+ workflows: [pre-commit]
+ types:
+ - completed
+
+jobs:
+ # comment with a message on how to do pre-commit
+ # if the pre-commit check was not passed
+ report-precommit-failure:
+ runs-on: ubuntu-latest
+ if: ${{ github.event.workflow_run.conclusion == 'failure' }}
+ steps:
+ - name: 'Download artifact'
+ uses: actions/github-script@v6
+ with:
+ script: |
+ let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ run_id: context.payload.workflow_run.id,
+ });
+ let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
+ return artifact.name == "pr_number"
+ })[0];
+ let download = await github.rest.actions.downloadArtifact({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ artifact_id: matchArtifact.id,
+ archive_format: 'zip',
+ });
+ let fs = require('fs');
+ fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/pr_number.zip`, Buffer.from(download.data));
+
+ - name: 'Unzip artifact'
+ run: unzip pr_number.zip
+
+ - name: 'Comment on PR'
+ uses: actions/github-script@v6
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ let fs = require('fs');
+ let issue_number = Number(fs.readFileSync('./pr_number'));
+ let owner = context.repo.owner;
+ let repo = context.repo.repo;
+ let run_id = context.payload.workflow_run.id;
+ let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
+ let body = `
+          Your pre-commit check failed. Follow the steps below to run pre-commit on your files and keep the code style consistent.
+
+ 1. install pre-commit via "pip install pre-commit"
+ 2. install pre-commit hooks via "pre-commit install"
+ 3. run pre-commit on file with format error via "pre-commit run --files path" by replacing "path" with the actual file path
+ 4. commit and push to your branch
+
+ View your job at ${run_url}.
+          Read our "CONTRIBUTING.md" for more details on the code style.
+ `;
+ await github.rest.issues.createComment({
+ owner: owner,
+ repo: repo,
+ issue_number: issue_number,
+ body: body
+ });
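
The comment body above walks the contributor through installing and running pre-commit on the offending files. As a rough illustration only (not part of the workflow, and the file list is hypothetical), the same steps can be scripted:

```python
import subprocess

changed_files = ["colossalai/some_module.py"]    # hypothetical paths flagged by the check

subprocess.run(["pip", "install", "pre-commit"], check=True)   # step 1: install pre-commit
subprocess.run(["pre-commit", "install"], check=True)          # step 2: install the hooks
for path in changed_files:
    subprocess.run(["pre-commit", "run", "--files", path])     # step 3: run on each changed file
# step 4: commit and push the fixed files to the PR branch
```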
diff --git a/.github/workflows/report_test_coverage.yml b/.github/workflows/report_test_coverage.yml
new file mode 100644
index 000000000..dc3fe395f
--- /dev/null
+++ b/.github/workflows/report_test_coverage.yml
@@ -0,0 +1,74 @@
+name: Report Test Coverage
+
+on:
+ workflow_run:
+ workflows: [Build]
+ types:
+ - completed
+
+jobs:
+ report-test-coverage:
+ runs-on: ubuntu-latest
+ steps:
+ - name: 'Download artifact'
+ uses: actions/github-script@v6
+ with:
+ script: |
+ let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ run_id: context.payload.workflow_run.id,
+ });
+ let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
+ return artifact.name == "report"
+ })[0];
+ let download = await github.rest.actions.downloadArtifact({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ artifact_id: matchArtifact.id,
+ archive_format: 'zip',
+ });
+ let fs = require('fs');
+ fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data));
+
+ - name: 'Unzip artifact'
+ run: |
+ unzip report.zip
+
+ - name: Code Coverage Report
+ uses: irongut/CodeCoverageSummary@v1.3.0
+ with:
+ filename: coverage.xml
+ badge: true
+ format: markdown
+ hide_branch_rate: false
+ hide_complexity: false
+ indicators: true
+ output: both
+ thresholds: '80 90'
+
+    - name: Make Coverage Report Collapsible
+      run: |
+        sed -i '2 i <details>' code-coverage-results.md
+        sed -i '3 i <summary>Click me to view the complete report</summary>' code-coverage-results.md
+        echo "</details>" >> code-coverage-results.md
+
+ - name: 'Comment on PR'
+ uses: actions/github-script@v6
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ let fs = require('fs');
+ let issue_number = Number(fs.readFileSync('./pr_number'));
+ let owner = context.repo.owner;
+ let repo = context.repo.repo;
+ let run_id = context.payload.workflow_run.id;
+ let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
+ let body = fs.readFileSync('./code-coverage-results.md', {encoding:'utf8', flag:'r'})
+
+ await github.rest.issues.createComment({
+ owner: owner,
+ repo: repo,
+ issue_number: issue_number,
+ body: body
+ });
diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
new file mode 100644
index 000000000..04d2063ec
--- /dev/null
+++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
@@ -0,0 +1,27 @@
+import argparse
+import os
+
+
+def check_inputs(input_list):
+ for path in input_list:
+ real_path = os.path.join('examples', path)
+ if not os.path.exists(real_path):
+ return False
+ return True
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+ args = parser.parse_args()
+ name_list = args.fileNameList.split(",")
+ is_correct = check_inputs(name_list)
+
+ if is_correct:
+ print('success')
+ else:
+ print('failure')
+
+
+if __name__ == '__main__':
+ main()
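
As a usage sketch (the paths are hypothetical, and `check_inputs` from the script above is assumed to be in scope), the check simply verifies that every requested directory exists under `examples/` when run from the repository root:

```python
# run from the repository root so that os.path.join('examples', path) resolves correctly
print(check_inputs(["language/gpt", "images/vit"]))   # True only if both folders exist
print(check_inputs(["language/no_such_example"]))     # False -> the workflow prints 'failure'
```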
diff --git a/.github/workflows/scripts/weekly_check_example.py b/.github/workflows/scripts/example_checks/check_example_weekly.py
similarity index 76%
rename from .github/workflows/scripts/weekly_check_example.py
rename to .github/workflows/scripts/example_checks/check_example_weekly.py
index dfedc4628..941e90901 100644
--- a/.github/workflows/scripts/weekly_check_example.py
+++ b/.github/workflows/scripts/example_checks/check_example_weekly.py
@@ -5,9 +5,9 @@ def show_files(path, all_files):
# Traverse all the folder/file in current directory
file_list = os.listdir(path)
# Determine the element is folder or file. If file, pass it into list, if folder, recurse.
- for file in file_list:
+ for file_name in file_list:
# Get the abs directory using os.path.join() and store into cur_path.
- cur_path = os.path.join(path, file)
+ cur_path = os.path.join(path, file_name)
# Determine whether folder
if os.path.isdir(cur_path):
show_files(cur_path, all_files)
@@ -26,9 +26,8 @@ def main():
for file_loc in contents:
split_loc = file_loc.split('/')
# must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
- if len(split_loc) - split_loc.index('examples') >= 3:
- tmp_loc = split_loc[(split_loc.index('examples') + 1):(split_loc.index('examples') + 3)]
- re_loc = join(tmp_loc, '/')
+ if len(split_loc) >= 4:
+ re_loc = '/'.join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
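
A worked example of the simplified path handling above (the sample path is hypothetical): only the two folder levels directly under `examples/` are kept.

```python
file_loc = "examples/images/vit/test_ci.sh"      # hypothetical changed file
split_loc = file_loc.split('/')                   # ['examples', 'images', 'vit', 'test_ci.sh']
if len(split_loc) >= 4:                           # at least two sub-folder levels below examples/
    re_loc = '/'.join(split_loc[1:3])
    print(re_loc)                                 # -> 'images/vit'
```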
diff --git a/.github/workflows/scripts/changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py
similarity index 52%
rename from .github/workflows/scripts/changed_example.py
rename to .github/workflows/scripts/example_checks/detect_changed_example.py
index ac2f0864e..df4fd6736 100644
--- a/.github/workflows/scripts/changed_example.py
+++ b/.github/workflows/scripts/example_checks/detect_changed_example.py
@@ -3,14 +3,19 @@ import argparse
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--fileNameList', type=str)
+ parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
for loc in name_list:
- # Find only the sub-folder of 'example' folder
+        # Find only the sub-sub-folders of the 'examples' folder
+ # the examples folder structure is like
+ # - examples
+ # - area
+ # - application
+ # - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
- folder_need_check.add(loc.split("/")[1] + "/" + loc.split("/")[2])
+ folder_need_check.add('/'.join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))
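
For illustration, a short walk-through of the detection logic above with a hypothetical colon-separated file list: only paths at least four levels deep under `examples/` contribute an `area/application` entry.

```python
name_list = "examples/images/vit/train.py:examples/README.md:colossalai/utils.py".split(":")
folder_need_check = set()
for loc in name_list:
    parts = loc.split("/")
    if parts[0] == "examples" and len(parts) >= 4:
        folder_need_check.add('/'.join(parts[1:3]))
print(list(folder_need_check))                    # -> ['images/vit']
```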
diff --git a/.github/workflows/scripts/input_check_example.py b/.github/workflows/scripts/input_check_example.py
deleted file mode 100644
index 5602d8f09..000000000
--- a/.github/workflows/scripts/input_check_example.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import argparse
-import os
-
-
-def detect_correct(loc_li):
- for loc in loc_li:
- real_loc = 'examples/' + eval(loc)
- if not os.path.exists(real_loc):
- return -1
- return 1
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument('--fileNameList', type=str)
- args = parser.parse_args()
- name_list = args.fileNameList.split(",")
- result = detect_correct(name_list)
- print(result)
-
-
-if __name__ == '__main__':
- main()
diff --git a/.github/workflows/translate_comment.yml b/.github/workflows/translate_comment.yml
new file mode 100644
index 000000000..83c127b3c
--- /dev/null
+++ b/.github/workflows/translate_comment.yml
@@ -0,0 +1,18 @@
+name: 'issue-translator'
+on:
+ issue_comment:
+ types: [created]
+ issues:
+ types: [opened]
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: usthe/issues-translate-action@v2.7
+ with:
+ IS_MODIFY_TITLE: false
+        # not required, default false. Decide whether to modify the issue title.
+        # If true, the robot account @Issues-translate-bot must have modification permissions; invite @Issues-translate-bot to your project or use your custom bot.
+ CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
+        # not required. Customize the translation robot prefix message.
diff --git a/.gitignore b/.gitignore
index 6b6f980e3..bf74a7538 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,3 +151,7 @@ colossalai/version.py
# ignore python interface defition file
.pyi
+
+# ignore coverage test file
+coverage.lcov
+coverage.xml
diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index 8edcff28b..5ad22785c 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -5,10 +5,10 @@
Colossal-AI: 一个面向大模型时代的通用深度学习系统
- 论文 |
- 文档 |
- 例程 |
- 论坛 |
+
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
@@ -35,7 +35,7 @@
为何选择 Colossal-AI
特点
- 并行训练样例展示
+ 并行训练样例展示
- 单GPU训练样例展示
+ 单GPU训练样例展示
- 推理 (Energon-AI) 样例展示
+ 推理 (Energon-AI) 样例展示
- Colossal-AI 成功案例
+ Colossal-AI 成功案例
- AIGC: 加速 Stable Diffusion
- 生物医药: 加速AlphaFold蛋白质结构预测
@@ -131,7 +131,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- 用相同的硬件训练24倍大的模型
-- 超3倍的吞吐量
+- 超3倍的吞吐量
### BERT
@@ -145,7 +145,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型,由于完全公开了预训练参数权重,因此促进了下游任务和应用部署的发展。
-- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt)
+- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md)
请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI-Examples) 以了解详情。
@@ -199,7 +199,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
-- [OPT推理服务](https://service.colossalai.org/opt): 无需注册,免费体验1750亿参数OPT在线推理服务
+- [OPT推理服务](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md): 无需注册,免费体验1750亿参数OPT在线推理服务
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9b431a45baba43b9581fb5cf3d4bf39a2aaea5d6
GIT binary patch
literal 559
zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfho3-LBpE?19EtC
zc(Zvkdvil&n1O5#AV!C5-n>9@5CoYa3gm$x8x%4FX(k|K2dm)6Fbkp+L4xc+SI_3n
zkg6Bp&Can)mSgS;pg%!40H?qC8KmLVH`JX0-fV0-P(^agx^U%S_W*e?x*94I
d#0X&k^|6CkXQ6x$72wUv1`=ZeLXdihS^z_{RyzOy
literal 0
HcmV?d00001
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt
new file mode 100644
index 0000000000000000000000000000000000000000..79a448c1b06f1db8731d2d45f988ff0b57810b04
GIT binary patch
literal 943
zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfho3-LBpE?19EtC
zc(Zvkdvil&n1O5#AV!C5-n>9@5CoYa3gm$x8x%4FX(k|K2dm)6Fbkp+L4xeSrykYz
zQM+js=4{>!sd@q4>>LMcl}lup7#Kh}0B2g`XHbWywX)Qr;>`R!Hz#GZq=u62U>svE
zkS!PIrH2A7U;yC&Z$=OWPt(XQ5CBP_0Q3}&t{d58eiWTKKwDtCp>7WFW@FQVDw1Q?
qg)0ZU2grlb)livBPywJmc94)SGem+BNCkMavVnL^KnPL~Q40VIP?Na;
literal 0
HcmV?d00001
diff --git a/examples/language/gpt/experiments/pipeline_parallel/requirements.txt b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt
new file mode 100644
index 000000000..137a69e80
--- /dev/null
+++ b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
index 79efa61b0..c3451c18d 100644
--- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
+++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
@@ -120,7 +120,7 @@ def run_master(args):
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
- pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
+ pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):
diff --git a/examples/language/gpt/gemini/benchmark_gemini.sh b/examples/language/gpt/gemini/benchmark_gemini.sh
index 13086666e..9a630b2ff 100644
--- a/examples/language/gpt/gemini/benchmark_gemini.sh
+++ b/examples/language/gpt/gemini/benchmark_gemini.sh
@@ -1,18 +1,20 @@
for MODEL_TYPE in "gpt2_medium"; do
- for BATCH_SIZE in 16; do
- for GPUNUM in 1 2 4 8; do
- for TPDEGREE in 1 2 4 8; do
- if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
- continue
- fi
- for PLACEMENT in "cpu" "auto"; do
- echo "****************** Begin ***************************"
- echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
- MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
- bash ./gemini/run_gemini.sh
- echo "****************** Finished ***************************"
- echo ""
- echo ""
+ for DISTPLAN in "colossalai"; do
+ for BATCH_SIZE in 16; do
+ for GPUNUM in 1 2 4 8; do
+ for TPDEGREE in 1 2 4 8; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ for PLACEMENT in "cpu" "auto"; do
+ echo "****************** Begin ***************************"
+          echo "+ benchmarking MODEL ${MODEL_TYPE} DISTPLAN ${DISTPLAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}"
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
+ bash ./run_gemini.sh
+ echo "****************** Finished ***************************"
+ echo ""
+ echo ""
+ done
done
done
done
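
The nested loops above sweep model type, distribution plan, batch size, GPU count, tensor-parallel degree and placement policy, skipping combinations where the tensor-parallel degree exceeds the number of GPUs. A small Python sketch (not part of the patch) that enumerates the same grid:

```python
from itertools import product

# mirror the sweep: the TP degree must not exceed the number of GPUs
for gpunum, tpdegree, placement in product([1, 2, 4, 8], [1, 2, 4, 8], ["cpu", "auto"]):
    if tpdegree > gpunum:
        continue
    print(f"MODEL=gpt2_medium DISTPLAN=colossalai BS=16 GPUNUM={gpunum} TPDEGREE={tpdegree} PLACEMENT={placement}")
```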
diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py
index c31b3fa6d..65124d9e4 100644
--- a/examples/language/gpt/gemini/commons/model_zoo.py
+++ b/examples/language/gpt/gemini/commons/model_zoo.py
@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
+def gpt2_30b(checkpoint=True):
+ return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_40b(checkpoint=True):
+ return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
+
+
def model_builder(model_size: str) -> callable:
if model_size == "gpt2_medium":
return gpt2_medium
@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
return gpt2_20b
elif model_size == "gpt2_24b":
return gpt2_24b
+ elif model_size == "gpt2_30b":
+ return gpt2_30b
+ elif model_size == "gpt2_40b":
+ return gpt2_40b
else:
raise TypeError(f"model_builder {model_size}")
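
A minimal usage sketch for the new model sizes registered above, assuming `model_zoo` from the example is importable; the configuration values come from the functions shown in the diff:

```python
# gpt2_30b: hidden_size=8192, 37 layers; gpt2_40b: hidden_size=8192, 50 layers
build_fn = model_builder("gpt2_30b")
model = build_fn(checkpoint=True)     # activation checkpointing enabled
```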
diff --git a/examples/language/gpt/gemini/requirements.txt b/examples/language/gpt/gemini/requirements.txt
new file mode 100644
index 000000000..137a69e80
--- /dev/null
+++ b/examples/language/gpt/gemini/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh
index ad577c350..6f0710d54 100644
--- a/examples/language/gpt/gemini/run_gemini.sh
+++ b/examples/language/gpt/gemini/run_gemini.sh
@@ -1,15 +1,15 @@
set -x
# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPAN=${DISTPAN:-"colossalai"}
+export DISTPLAN=${DISTPLAN:-"colossalai"}
-# The following options only valid when DISTPAN="colossalai"
+# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
-
+export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH
mkdir -p gemini_logs
@@ -20,5 +20,6 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
--shardinit=${USE_SHARD_INIT} \
---distplan=${DISTPAN} \
-2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
+--distplan=${DISTPLAN} \
+--train_step=${TRAIN_STEP} \
+2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh
new file mode 100644
index 000000000..6079d5ed6
--- /dev/null
+++ b/examples/language/gpt/gemini/test_ci.sh
@@ -0,0 +1,35 @@
+set -x
+$(cd `dirname $0`;pwd)
+export TRAIN_STEP=4
+
+for MODEL_TYPE in "gpt2_medium"; do
+ for DISTPLAN in "colossalai"; do
+ for BATCH_SIZE in 2; do
+ for GPUNUM in 1 4; do
+ for TPDEGREE in 1 2; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ for PLACEMENT in "cpu" "auto"; do
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
+ bash ./run_gemini.sh
+ done
+ done
+ done
+ done
+ done
+
+ for DISTPLAN in "zero1" "zero2"; do
+ for BATCH_SIZE in 2; do
+ for GPUNUM in 1 4; do
+ for TPDEGREE in 1; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
+ bash ./run_gemini.sh
+ done
+ done
+ done
+ done
+done
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 29f8c8ef1..285706596 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -65,6 +65,13 @@ def parse_args():
default="gpt2_medium",
help="model model scale",
)
+ parser.add_argument(
+ "--train_step",
+ type=int,
+ default=10,
+ help="training iterations for test",
+ )
+
args = parser.parse_args()
return args
@@ -180,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
fp16_init_scale = 2**5
gpu_margin_mem_ratio_for_auto = 0
if version.parse(CAI_VERSION) > version.parse("0.1.10"):
model = GeminiDDP(model,
+ strict_ddp_mode=ddp_flag,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
hidden_dim=model.config.n_embd,
- search_range_mb=64)
+ search_range_mb=128)
# configure the const policy
if placement_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
@@ -236,7 +244,8 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
- NUM_STEPS = 10
+ NUM_STEPS = args.train_step
+
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
@@ -270,14 +279,17 @@ def main():
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
- tensor_parallelize(model, tp_pg)
+    # Note that v0.1.10 is not compatible with TP degree > 1
+ if args.tp_degree > 1:
+ tensor_parallelize(model, tp_pg)
# build a Gemini model and a highly optimized cpu optimizer
# Gemini + ZeRO DP, Note it must be used after TP
- model, optimizer = build_gemini(model, tp_pg, args.placement)
+ model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
else:
+ assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
model = model_builder(args.model_type)(checkpoint=True).cuda()
if args.distplan.startswith("torch"):
@@ -288,12 +300,17 @@ def main():
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
elif args.distplan.startswith("zero"):
- partition_flag = args.distplan == "zero2"
+ model = model.half()
+ partition_flag = (args.distplan == "zero2")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
- optimizer = LowLevelZeroOptimizer(optimizer,
- overlap_communication=True,
- partition_grad=partition_flag,
- verbose=True)
+
+ optimizer = LowLevelZeroOptimizer(
+ optimizer,
+ reduce_bucket_size=12 * 1024 * 1024,
+ overlap_communication=True,
+ partition_grad=partition_flag,
+ verbose=True,
+ )
# model is shared after TP
numel = get_model_size(model)
diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt
index e1f131468..ef58bb76b 100644
--- a/examples/language/gpt/requirements.txt
+++ b/examples/language/gpt/requirements.txt
@@ -1 +1,2 @@
transformers >= 4.23
+colossalai
diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh
index ad0cfa325..d67c17229 100644
--- a/examples/language/gpt/test_ci.sh
+++ b/examples/language/gpt/test_ci.sh
@@ -1,16 +1,2 @@
-pip install -r requirements.txt
-
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPAN="colossalai"
-
-# The following options only valid when DISTPAN="colossalai"
-export TPDEGREE=2
-export GPUNUM=4
-export PLACEMENT='cpu'
-export USE_SHARD_INIT=False
-export BATCH_SIZE=8
-export MODEL_TYPE="gpt2_medium"
-
-
-mkdir -p logs
-torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
+set -x
+cd gemini && bash test_ci.sh
diff --git a/examples/language/gpt/titans/LICENSE b/examples/language/gpt/titans/LICENSE
new file mode 100644
index 000000000..261eeb9e9
--- /dev/null
+++ b/examples/language/gpt/titans/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/examples/language/gpt/titans/README.md b/examples/language/gpt/titans/README.md
new file mode 100644
index 000000000..fe1854c9f
--- /dev/null
+++ b/examples/language/gpt/titans/README.md
@@ -0,0 +1,48 @@
+# Run GPT With Colossal-AI
+
+## How to Prepare Webtext Dataset
+
+You can download the preprocessed sample dataset for this demo via our [Google Drive sharing link](https://drive.google.com/file/d/1QKI6k-e2gJ7XgS8yIpgPPiMmwiBP_BPE/view?usp=sharing).
+
+
+You can also skip dataset preparation by passing `--use_dummy_dataset` when running the demo.
+
+## Run this Demo
+
+Use the following commands to install prerequisites.
+
+```bash
+# assuming using cuda 11.3
+pip install -r requirements.txt
+```
+
+Use the following commands to execute training.
+
+```Bash
+#!/usr/bin/env sh
+# if you want to use real dataset, then remove --use_dummy_dataset
+# export DATA=/path/to/small-gpt-dataset.json
+
+# run on a single node
+colossalai run --nproc_per_node=<num_gpus> train_gpt.py --config configs/<config_file> --from_torch --use_dummy_dataset
+
+# run on multiple nodes with slurm
+colossalai run --nproc_per_node=<num_gpus> \
+   --master_addr <hostname> \
+   --master_port <port> \
+   --hosts <comma-separated-hostnames> \
+   train_gpt.py \
+   --config configs/<config_file> \
+ --from_torch \
+ --use_dummy_dataset
+
+# run on multiple nodes with slurm
+srun python \
+ train_gpt.py \
+   --config configs/<config_file> \
+   --host <master_node> \
+ --use_dummy_dataset
+
+```
+
+You can set `<config_file>` to any file in the `configs` folder. To get it running quickly, start with `gpt2_small_zero3_pp1d.py` on a single node first. The config file contains explanations of how to change the parallel setting.
diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
new file mode 100644
index 000000000..7bf533039
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT2_small_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 8
+NUM_EPOCHS = 10
+SEQ_LEN = 1024
+NUM_MICRO_BATCHES = 4
+HIDDEN_SIZE = 768
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do not want zero, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+ optimizer_config=dict(initial_scale=2**5))
+
+optimizer = dict(
+ type=HybridAdam,
+ lr=0.000015,
+ weight_decay=1e-2,
+)
+
+model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'),
+)
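
As the comments above note, the parallel layout is changed by editing the `parallel` dictionary. A hypothetical variant (not part of the patch) with two pipeline stages and 4-way 1D tensor parallelism, which would require 8 GPUs in total:

```python
# hypothetical alternative parallel setting for the config above
parallel = dict(
    pipeline=2,                       # two pipeline stages
    tensor=dict(size=4, mode='1d'),   # 4-way 1D tensor parallelism
)
```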
diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
new file mode 100644
index 000000000..9f9816b30
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT3_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 192
+NUM_EPOCHS = 60
+SEQ_LEN = 2048
+NUM_MICRO_BATCHES = 192
+HIDDEN_SIZE = 12288
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do not want zero, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+ optimizer_config=dict(initial_scale=2**16))
+
+optimizer = dict(
+ type=HybridAdam,
+ lr=0.00015,
+ weight_decay=1e-2,
+)
+
+model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None
+)
diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py
new file mode 100644
index 000000000..64f5944a9
--- /dev/null
+++ b/examples/language/gpt/titans/dataset/webtext.py
@@ -0,0 +1,43 @@
+import json
+import os
+from typing import Optional
+
+import torch
+from torch.utils.data import Dataset
+from transformers import GPT2Tokenizer
+
+from colossalai.registry import DATASETS
+
+
+@DATASETS.register_module
+class WebtextDataset(Dataset):
+
+ def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
+ super().__init__()
+ if path is not None:
+ root = os.path.dirname(path)
+ encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+ if os.path.isfile(encoded_data_cache_path):
+ seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
+ if seq_len_ == seq_len:
+ self.data = data
+ self.attention_mask = attention_mask
+ return
+ raw_data = []
+ with open(path) as f:
+ for line in f.readlines():
+ raw_data.append(json.loads(line)['text'])
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.unk_token
+ encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+ self.data = encoded_data['input_ids']
+ self.attention_mask = encoded_data['attention_mask']
+ else:
+ self.data = torch.randint(0, 50257, (10240, seq_len))
+ self.attention_mask = torch.ones_like(self.data)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
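
A brief usage sketch for the dataset above (assuming the module is importable from the titans example directory): with `path=None` it falls back to random dummy data, matching the `--use_dummy_dataset` flow described in the README.

```python
from torch.utils.data import DataLoader

dataset = WebtextDataset(path=None, seq_len=1024)     # dummy data, no download or tokenizer needed
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch, labels = next(iter(loader))
print(batch['input_ids'].shape, labels.shape)          # torch.Size([8, 1024]) torch.Size([8, 1024])
```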
diff --git a/examples/language/gpt/titans/model/__init__.py b/examples/language/gpt/titans/model/__init__.py
new file mode 100644
index 000000000..eec48ef89
--- /dev/null
+++ b/examples/language/gpt/titans/model/__init__.py
@@ -0,0 +1,3 @@
+from .embed import vocab_parallel_cross_entropy
+from .gpt1d import *
+from .pipeline_gpt1d import *
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
new file mode 100644
index 000000000..6369b9f8c
--- /dev/null
+++ b/examples/language/gpt/titans/model/embed.py
@@ -0,0 +1,599 @@
+import torch
+import torch.nn.init as init
+from torch import Tensor
+from torch import distributed as dist
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+
+from colossalai.context import ParallelMode, seed
+from colossalai.core import global_context as gpc
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
+from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
+from colossalai.nn.layer.utils import divide
+from colossalai.registry import LAYERS, LOSSES, MODELS
+from colossalai.utils import get_current_device
+
+
+class VocabParallelEmbedding(torch.nn.Module):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ num_tokentypes=0,
+ dtype=torch.float):
+ super(VocabParallelEmbedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.num_tokentypes = num_tokentypes
+
+ # Word embeddings (parallel).
+ self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
+ self._word_embeddings_key = 'word_embeddings'
+
+ # Position embedding (serial).
+ self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
+ self._position_embeddings_key = 'position_embeddings'
+ # Initialize the position embeddings.
+ # self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def zero_parameters(self):
+ """Zero out all parameters in embedding."""
+ self.word_embeddings.weight.data.fill_(0)
+ self.word_embeddings.weight.shared = True
+ self.position_embeddings.weight.data.fill_(0)
+ self.position_embeddings.weight.shared = True
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings.weight.data.fill_(0)
+ self.tokentype_embeddings.weight.shared = True
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+ # Embeddings.
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ words_embeddings = self.word_embeddings(input_ids)
+
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+ if position_ids is None:
+ position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = words_embeddings + position_embeddings
+
+ # Dropout.
+ with seed(ParallelMode.TENSOR):
+ embeddings = self.embedding_dropout(embeddings)
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Position embedding.
+ if self._position_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._position_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'position_embeddings' in key:
+ state_dict_[key.split('position_embeddings.')[1]] \
+ = state_dict[key]
+ self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it',
+ flush=True)
+
+
+class VocabParallelEmbedding1D(torch.nn.Module):
+ """Embedding parallelized in the vocabulary dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
+ super(VocabParallelEmbedding1D, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ # Set the details for compatibility.
+ self.padding_idx = None
+ self.max_norm = None
+ self.norm_type = 2.
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+ self.tensor_model_parallel_size = gpc.tensor_parallel_size
+ # Divide the weight matrix along the vocabulary dimension.
+ self.vocab_start_index, self.vocab_end_index = \
+ VocabUtility.vocab_range_from_global_vocab_size(
+ self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
+ self.tensor_model_parallel_size)
+ self.num_embeddings_per_partition = self.vocab_end_index - \
+ self.vocab_start_index
+
+ # Allocate weights and initialize.
+ factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
+ init.uniform_(self.weight, -1, 1)
+
+ def forward(self, input_):
+ if self.tensor_model_parallel_size > 1:
+ # Build the mask.
+ input_mask = (input_ < self.vocab_start_index) | \
+ (input_ >= self.vocab_end_index)
+ # Mask the input.
+ masked_input = input_.clone() - self.vocab_start_index
+ masked_input[input_mask] = 0
+ else:
+ masked_input = input_
+ # Get the embeddings.
+ output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.sparse)
+ # Mask the output embedding.
+ if self.tensor_model_parallel_size > 1:
+ output_parallel[input_mask, :] = 0.0
+ # Reduce across all the model parallel GPUs.
+ output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+ return output
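+
+# A minimal illustration of the lookup flow above (hypothetical numbers, not
+# part of the training path): with 2 tensor-parallel ranks and a vocabulary of
+# 8 tokens, rank 0 owns ids [0, 4) and rank 1 owns ids [4, 8). For input id 6
+# seen on rank 0, the id falls outside [0, 4), so it is masked to 0, looked up
+# locally, its output row is zeroed, and the final all-reduce restores the
+# correct embedding contributed by rank 1.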
+
+
+@LOSSES.register_module
+class vocab_parallel_cross_entropy(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, vocab_parallel_logits, target):
+ """Shift logits/labels for language modeling and compute vocab-parallel cross entropy."""
+ vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
+ target = target[..., 1:].contiguous()
+ return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
+ target.view(-1))
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, vocab_parallel_logits, target):
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(logits_max,
+ op=torch.distributed.ReduceOp.MAX,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+ # Get the partition's vocab indices
+ get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+ world_size = gpc.tensor_parallel_size
+ vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+
+ # Create a mask of out-of-range vocab ids (True means the id belongs to another partition and must be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(predicted_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(sum_exp_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+ loss = loss.mean()
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ # Retrieve tensors saved in the forward pass.
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+ # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None
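+
+# For reference, with a single tensor-parallel rank the forward above reduces
+# to the standard cross entropy (shown here only as an illustrative sketch):
+#   loss = (torch.log(torch.exp(logits).sum(-1))
+#           - logits.gather(-1, target[..., None]).squeeze(-1)).mean()
+# which matches F.cross_entropy(logits.view(-1, V), target.view(-1)) up to the
+# max-subtraction used above for numerical stability.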
+
+
+class VocabUtility:
+ """Split the vocabulary into `world_size` chunks and return the
+ first and last index of the vocabulary belonging to the `rank`
+ partition. Note that indices are in [first, last)."""
+
+ @staticmethod
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+ index_f = rank * per_partition_vocab_size
+ index_l = index_f + per_partition_vocab_size
+ return index_f, index_l
+
+ @staticmethod
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
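+
+# Example of the helpers above (hypothetical numbers): splitting a vocabulary
+# of 50304 tokens across 4 ranks gives 12576 rows per partition, so
+# vocab_range_from_global_vocab_size(50304, rank, 4) returns
+#   rank 0 -> (0, 12576),  rank 1 -> (12576, 25152),
+#   rank 2 -> (25152, 37728), rank 3 -> (37728, 50304).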
+
+
+class VocabParallelGPTLMHead1D(ParallelLayer):
+ """
+ Language model head that shares the same parameters with the embedding matrix.
+ """
+
+ def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
+ super().__init__()
+ if embed is not None:
+ self.head = embed
+ else:
+ self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = reduce_grad(x, ParallelMode.PARALLEL_1D)
+ x = F.linear(x, self.head.weight)
+ return x
+
+
+###################################
+
+
+class HiddenParallelEmbedding(torch.nn.Module):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ dtype=torch.float,
+ padding_idx: int = 0,
+ num_tokentypes=0,
+ ):
+ super(HiddenParallelEmbedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.num_tokentypes = num_tokentypes
+
+ # Word embeddings (parallel).
+ self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+ self._word_embeddings_key = 'word_embeddings'
+
+ # Position embedding (serial).
+ self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
+ self._position_embeddings_key = 'position_embeddings'
+ # Initialize the position embeddings.
+ # self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrained model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def zero_parameters(self):
+ """Zero out all parameters in embedding."""
+ self.word_embeddings.weight.data.fill_(0)
+ self.word_embeddings.weight.shared = True
+ self.position_embeddings.weight.data.fill_(0)
+ self.position_embeddings.weight.shared = True
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings.weight.data.fill_(0)
+ self.tokentype_embeddings.weight.shared = True
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings are already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ words_embeddings = self.word_embeddings(input_ids)
+
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+ if position_ids is None:
+ position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=get_current_device())
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = words_embeddings + position_embeddings
+
+ # Dropout.
+ with seed(ParallelMode.TENSOR):
+ embeddings = self.embedding_dropout(embeddings)
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Position embedding.
+ if self._position_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._position_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'position_embeddings' in key:
+ state_dict_[key.split('position_embeddings.')[1]] \
+ = state_dict[key]
+ self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it',
+ flush=True)
+
+
+class HiddenParallelEmbedding1D(torch.nn.Module):
+ """Embedding parallelized in the hidden (embedding) dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
+ super(HiddenParallelEmbedding1D, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
+ # Set the details for compatibility.
+ self.padding_idx = padding_idx
+ self.max_norm = None
+ self.norm_type = 2.
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+
+ # Allocate weights and initialize.
+ factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+ self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
+ init.uniform_(self.weight, -1, 1)
+
+ def forward(self, input_):
+
+ # Get the embeddings.
+ output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.sparse)
+
+ # Reduce across all the model parallel GPUs.
+ output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+ return output
+
+
+@LAYERS.register_module
+class HiddenParallelGPTLMHead1D(ParallelLayer):
+ """
+ Language model head that shares the same parameters with the embedding matrix.
+ """
+
+ def __init__(
+ self,
+ embed=None,
+ embed_dim=None,
+ vocab_size=None,
+ dtype=None,
+ ):
+ super().__init__()
+ if embed is not None:
+ self.head = embed
+ self.synced_embed = True
+ else:
+ # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+ # (hidden_size/q, vocab_size)
+ self.synced_embed = False
+ self.head = Linear1D_Row(in_features=embed_dim,
+ out_features=vocab_size,
+ bias=False,
+ dtype=dtype,
+ parallel_input=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.synced_embed:
+ x = F.linear(x, self.head.weight)
+ else:
+ x = self.head(x)
+
+ return x
diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py
new file mode 100644
index 000000000..2edd03606
--- /dev/null
+++ b/examples/language/gpt/titans/model/gpt1d.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import math
+
+import torch
+from torch import Tensor
+from torch import nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.core import global_context as gpc
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.utils import ACT2FN, divide
+from colossalai.utils.activation_checkpoint import checkpoint
+
+__all__ = [
+ 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
+]
+
+
+class GPTMLP1D(ParallelLayer):
+
+ def __init__(
+ self,
+ in_features: int,
+ mlp_ratio: int,
+ act_func: str = 'gelu',
+ dropout_prob: float = 0.,
+ dtype=None,
+ checkpoint: bool = False,
+ skip_bias_add: bool = False,
+ ):
+ super().__init__()
+
+ self.in_features = in_features
+ self.mlp_ratio = mlp_ratio
+ self.checkpoint = checkpoint
+ self.skip_bias_add = skip_bias_add
+
+ self.act = ACT2FN[act_func]
+ skip_dense_1_add_bias = False
+
+ # Project to mlp_ratio * h.
+ self.dense_1 = Linear1D_Col(
+ self.in_features,
+ int(self.mlp_ratio * self.in_features),
+ dtype=dtype,
+ gather_output=False,
+ skip_bias_add=skip_dense_1_add_bias,
+ )
+
+ # Project back to h.
+ self.dense_2 = Linear1D_Row(
+ int(self.mlp_ratio * self.in_features),
+ self.in_features,
+ dtype=dtype,
+ parallel_input=True,
+ )
+
+ self.dropout = col_nn.Dropout(dropout_prob)
+
+ def _forward(self, hidden_states: Tensor) -> Tensor:
+ intermediate_output = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output)
+
+ output = self.dense_2(intermediate_output)
+ output = self.dropout(output)
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, False, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
+
+class GenericGPTSelfAttention1D(ParallelLayer):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.attention_head_size = divide(hidden_size, num_attention_heads)
+ self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
+ self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
+ self.checkpoint = checkpoint
+ self.query_key_value = Linear1D_Col(
+ hidden_size,
+ 3 * hidden_size,
+ dtype=dtype,
+ )
+ self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
+ self.dense = Linear1D_Row(
+ hidden_size,
+ hidden_size,
+ dtype=dtype,
+ parallel_input=True,
+ )
+ self.dropout = col_nn.Dropout(hidden_dropout_prob)
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ raise NotImplementedError
+
+ def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ query_key_value = self.query_key_value(hidden_states)
+ new_qkv_shape = query_key_value.shape[:-1] + \
+ (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
+ query_key_value = query_key_value.view(new_qkv_shape)
+ query_key_value = query_key_value.permute((0, 2, 1, 3))
+ query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
+
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)
+
+ attention_scores = attention_scores.type(value_layer.dtype)
+
+ attention_probs = self.attention_dropout(attention_scores)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.transpose(1, 2)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+ output = self.dense(context_layer)
+ output = self.dropout(output)
+
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ return checkpoint(self._forward, False, hidden_states, attention_mask)
+
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states, attention_mask)
+ else:
+ return self._forward(hidden_states, attention_mask)
+
+
+class GPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024):
+ super().__init__(hidden_size,
+ num_attention_heads,
+ attention_dropout_prob,
+ hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings)
+ self.softmax = nn.Softmax(dim=-1)
+ max_positions = max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions),
+ dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # causal mask
+ query_length, key_length = query_layer.size(-2), key_layer.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
+ attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
+ if attention_mask is not None:
+ # Apply the attention mask
+ attention_scores = attention_scores + attention_mask
+ attention_scores = self.softmax(attention_scores)
+ return attention_scores
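+
+ # For reference, the registered `bias` buffer above is a lower-triangular
+ # matrix; e.g. for max_positions = 4 it looks like
+ #   [[1, 0, 0, 0],
+ #    [1, 1, 0, 0],
+ #    [1, 1, 1, 0],
+ #    [1, 1, 1, 1]]
+ # so position i may only attend to positions j <= i, and masked entries are
+ # filled with `masked_bias` (-1e4) before the softmax.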
+
+
+class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024):
+ super().__init__(hidden_size,
+ num_attention_heads,
+ attention_dropout_prob,
+ hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings)
+ self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
+ input_in_bf16=False,
+ attn_mask_type=AttnMaskType.causal,
+ scaled_masked_softmax_fusion=True,
+ mask_func=None,
+ softmax_in_fp32=True,
+ scale=math.sqrt(self.attention_head_size))
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ return self.softmax(attention_scores, attention_mask)
+
+
+class GenericGPTTransformerLayer1D(ParallelLayer):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
+ attention_dropout_prob: float = 0.,
+ hidden_dropout_prob: float = 0.,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ attention=None,
+ layer_norm=None):
+ super().__init__()
+ self.checkpoint = checkpoint
+ self.dtype = dtype
+ self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+ self.apply_post_layer_norm = apply_post_layer_norm
+ self.attention = attention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ max_position_embeddings=max_position_embeddings,
+ checkpoint=False,
+ )
+
+ self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+ self.mlp = GPTMLP1D(
+ in_features=hidden_size,
+ dropout_prob=hidden_dropout_prob,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ dtype=dtype,
+ checkpoint=False,
+ )
+
+ def _forward(self, hidden_states, attention_mask) -> Tensor:
+ if not self.apply_post_layer_norm:
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ if self.apply_post_layer_norm:
+ residual = hidden_states
+ attention_output = self.attention(hidden_states, attention_mask)
+ hidden_states = residual + attention_output
+
+ if not self.apply_post_layer_norm:
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ if self.apply_post_layer_norm:
+ residual = hidden_states
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + feed_forward_hidden_states
+
+ output = (hidden_states, attention_mask)
+ return output
+
+ def forward(self, hidden_states, attention_mask):
+ if self.checkpoint:
+ return checkpoint(self._forward, False, hidden_states, attention_mask)
+ else:
+ return self._forward(hidden_states, attention_mask)
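+
+ # Note on `apply_post_layer_norm` above: with the default False (pre-LN,
+ # GPT-2 style) the residual branches off *before* norm1/norm2, while with
+ # True the residual is taken from the already normalized activations.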
+
+
+class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4,
+ attention_dropout_prob: float = 0,
+ hidden_dropout_prob: float = 0,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 0.00001,
+ apply_post_layer_norm: bool = False):
+ attention = GPTSelfAttention1D
+ layer_norm = nn.LayerNorm
+ super().__init__(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm,
+ attention=attention,
+ layer_norm=layer_norm)
+
+
+class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4,
+ attention_dropout_prob: float = 0,
+ hidden_dropout_prob: float = 0,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 0.00001,
+ apply_post_layer_norm: bool = False):
+ attention = FusedGPTSelfAttention1D
+ layer_norm = kernel.LayerNorm
+ super().__init__(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm,
+ attention=attention,
+ layer_norm=layer_norm)
diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py
new file mode 100644
index 000000000..30180285b
--- /dev/null
+++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py
@@ -0,0 +1,322 @@
+import inspect
+
+# import model_zoo.gpt.gpt as col_gpt
+import titans.model.gpt.gpt as col_gpt
+import torch
+import torch.nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.pipeline.utils import partition_uniform
+
+from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
+from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
+
+__all__ = [
+ 'GPT2_small_pipeline_1D',
+ 'GPT2_exlarge_pipeline_1D',
+ 'GPT3_pipeline_1D',
+ 'GPT2_exlarge_pipeline_hybrid',
+ 'GPT2_small_pipeline_hybrid',
+ 'GPT3_pipeline_hybrid',
+]
+
+
+class GenericPipelineGPT(nn.Module):
+
+ def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
+ super().__init__()
+ self.embedding = embedding
+ self.blocks = blocks
+ self.norm = norm
+ self.head = head
+ assert blocks is not None
+ if norm is not None or head is not None:
+ assert norm is not None and head is not None
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ batch_size = hidden_states.shape[0]
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = attention_mask[:, None, None, :]
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
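+
+ # Shape note for the mask handling above: a padding mask of shape
+ # [batch, seq_len] (1 = keep, 0 = pad) is broadcast to
+ # [batch, 1, 1, seq_len] and converted into an additive bias of 0 for kept
+ # positions and -10000 for padded ones, which each transformer block then
+ # adds to its raw attention scores.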
+
+
+class PipelineGPT1D(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ embed_cls = VocabParallelEmbedding
+ head_cls = VocabParallelGPTLMHead1D
+ if embed_split_hidden:
+ embed_cls = HiddenParallelEmbedding
+ head_cls = HiddenParallelGPTLMHead1D
+ if first:
+ embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+ blocks = nn.ModuleList([
+ GPTTransformerLayer1D(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attn_drop_rate,
+ hidden_dropout_prob=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+ ])
+ if last:
+ norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+class FusedPipelineGPT1D(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ embed_cls = VocabParallelEmbedding
+ head_cls = VocabParallelGPTLMHead1D
+ if embed_split_hidden:
+ embed_cls = HiddenParallelEmbedding
+ head_cls = HiddenParallelGPTLMHead1D
+ if first:
+ embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+ blocks = nn.ModuleList([
+ FusedGPTTransformerLayer1D(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attn_drop_rate,
+ hidden_dropout_prob=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+ ])
+ if last:
+ norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
+
+
+class PipelineGPTHybrid(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ if first:
+ embedding = col_gpt.GPTEmbedding(hidden_size,
+ vocab_size,
+ max_position_embeddings,
+ dropout=embed_drop_rate,
+ dtype=dtype)
+ blocks = nn.ModuleList([
+ col_gpt.GPTBlock(hidden_size,
+ num_attention_heads,
+ mlp_ratio=mlp_ratio,
+ attention_dropout=attn_drop_rate,
+ dropout=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ activation=nn.functional.gelu) for _ in range(num_layers)
+ ])
+ if last:
+ norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ # head = col_gpt.GPTLMHead(vocab_size=vocab_size,
+ # hidden_size=hidden_size,
+ # dtype=dtype,
+ # bias=False)
+ head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+def _filter_kwargs(func, kwargs):
+ sig = inspect.signature(func)
+ return {k: v for k, v in kwargs.items() if k in sig.parameters}
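+
+# Example (hypothetical values): if module_cls.__init__ only declares
+# `hidden_size` and `dtype`, then
+#   _filter_kwargs(module_cls.__init__, {'hidden_size': 768, 'foo': 1})
+# returns {'hidden_size': 768}, dropping keys the constructor does not accept.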
+
+
+def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ else:
+ pipeline_size = 1
+ pipeline_rank = 0
+ rank = gpc.get_global_rank()
+
+ if pipeline_size > 1:
+ wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+ else:
+ wrapper = None
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+ for start, end in parts:
+ kwargs['num_layers'] = end - start
+ kwargs['first'] = start == 0
+ kwargs['last'] = end == num_layers
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
+
+ if wrapper is not None:
+ if start == 0:
+ wrapper.register_module(chunk.embedding.word_embeddings)
+ elif end == num_layers:
+ wrapper.register_module(chunk.head)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+
+ numel = 0
+ for _, param in model.named_parameters(recurse=True):
+ numel += param.numel()
+ logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
+ return model
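+
+# A sketch of the partitioning above (the exact return format of
+# partition_uniform is assumed here for illustration): building a 12-layer
+# model on 4 pipeline stages with num_chunks=1 assigns each rank one part of
+# 3 consecutive layers, e.g. rank 1 receives parts like [(3, 6)], so its chunk
+# is constructed with num_layers=3, first=False, last=False.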
+
+
+def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
+ model = FusedPipelineGPT1D if fused else PipelineGPT1D
+ return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
+
+
+def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
+
+
+def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=768,
+ num_attention_heads=12,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=1600,
+ num_attention_heads=32,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
+
+
+def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=12288,
+ num_attention_heads=96,
+ checkpoint=checkpoint,
+ max_position_embeddings=2048,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=1600,
+ num_attention_heads=32,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
+
+
+def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=768,
+ num_attention_heads=12,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
+
+
+def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=12288,
+ num_attention_heads=96,
+ checkpoint=checkpoint,
+ max_position_embeddings=2048,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
diff --git a/examples/language/gpt/titans/requirements.txt b/examples/language/gpt/titans/requirements.txt
new file mode 100644
index 000000000..64ff7a4ab
--- /dev/null
+++ b/examples/language/gpt/titans/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.12.1
+titans==0.0.7
+colossalai==0.2.0+torch1.12cu11.3
+-f https://release.colossalai.org
diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh
new file mode 100644
index 000000000..a1a7fc737
--- /dev/null
+++ b/examples/language/gpt/titans/run.sh
@@ -0,0 +1,3 @@
+export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
diff --git a/examples/language/gpt/titans/test_ci.sh b/examples/language/gpt/titans/test_ci.sh
new file mode 100644
index 000000000..7cb24c1a4
--- /dev/null
+++ b/examples/language/gpt/titans/test_ci.sh
@@ -0,0 +1 @@
+colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch --use_dummy_dataset
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
new file mode 100644
index 000000000..66225d6c8
--- /dev/null
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -0,0 +1,113 @@
+import contextlib
+import os
+
+import torch
+import torch.nn as nn
+from dataset.webtext import WebtextDataset
+from titans.model.gpt import GPTLMLoss
+
+import colossalai
+import colossalai.utils as utils
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn import LinearWarmupLR
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
+from colossalai.utils.timer import MultiTimer
+from colossalai.zero.init_ctx import ZeroInitContext
+
+
+def calc_local_model_size(model: torch.nn.Module):
+ numel_per_device = 0
+ for p in model.parameters():
+ numel_per_device += p.numel()
+ return numel_per_device
+
+
+VOCAB_SIZE = 50257
+
+
+def main():
+ parser = colossalai.get_default_parser()
+ parser.add_argument('--from_torch', default=False, action='store_true')
+ parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
+ args = parser.parse_args()
+ disable_existing_loggers()
+ if args.from_torch:
+ colossalai.launch_from_torch(config=args.config)
+ else:
+ colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
+ logger = get_dist_logger()
+
+ data_path = None if args.use_dummy_dataset else os.environ['DATA']
+ logger.info(f'Build data loader from path {data_path}', ranks=[0])
+
+ train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
+ train_dataloader = utils.get_dataloader(train_ds,
+ seed=42,
+ batch_size=gpc.config.BATCH_SIZE,
+ pin_memory=True,
+ shuffle=True,
+ drop_last=True)
+
+ logger.info('Build model', ranks=[0])
+ use_pipeline = is_using_pp()
+ use_interleaved = hasattr(gpc.config.model, 'num_chunks')
+ use_zero3 = hasattr(gpc.config, 'zero')
+ ctx = contextlib.nullcontext()
+ if use_zero3:
+ ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
+ shard_strategy=gpc.config.zero.model_config.shard_strategy,
+ shard_param=True)
+ with ctx:
+ model = gpc.config.model.pop('type')(**gpc.config.model)
+ if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
+ model = nn.ModuleList([model])
+
+ if use_zero3:
+ numel = ctx.model_numel_tensor.item()
+ else:
+ numel = calc_local_model_size(model)
+
+ tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
+ * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
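+    # Rough FLOP estimate per step: the factor 8 is a common heuristic of
+    # ~2 FLOPs/parameter/token for the forward pass, ~4 for the backward pass
+    # and ~2 for activation recomputation when checkpointing is enabled
+    # (an approximation, not an exact count).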
+
+ criterion = getattr(gpc.config, 'loss_fn', None)
+ if criterion is not None:
+ criterion = criterion.type()
+ else:
+ criterion = GPTLMLoss()
+ logger.info('Build optimizer', ranks=[0])
+ optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
+ lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
+ engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader=train_dataloader,
+ lr_scheduler=lr_scheduler)
+ global_batch_size = gpc.config.BATCH_SIZE * \
+ gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
+ logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
+ timer = MultiTimer()
+ trainer = Trainer(engine=engine, logger=logger, timer=timer)
+ hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
+ hooks.LogMetricByStepHook(),
+ hooks.LogMemoryByEpochHook(logger),
+ # hooks.LogMemoryByEpochHook(logger),
+ # hooks.LogTimingByEpochHook(timer, logger),
+ ]
+ trainer.fit(train_dataloader=train_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True,
+ return_output_label=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt
new file mode 100644
index 000000000..137a69e80
--- /dev/null
+++ b/examples/language/opt/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh
new file mode 100644
index 000000000..317f602cd
--- /dev/null
+++ b/examples/language/opt/test_ci.sh
@@ -0,0 +1,4 @@
+for GPUNUM in 2 1
+do
+env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh
+done
diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh
index 4aa868953..7a533509e 100644
--- a/examples/language/palm/run.sh
+++ b/examples/language/palm/run.sh
@@ -8,4 +8,4 @@ export PLACEMENT='cpu'
export USE_SHARD_INIT=False
export BATCH_SIZE=4
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
\ No newline at end of file
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh
new file mode 100644
index 000000000..f21095578
--- /dev/null
+++ b/examples/language/palm/test_ci.sh
@@ -0,0 +1,9 @@
+$(cd `dirname $0`;pwd)
+
+for BATCH_SIZE in 2
+do
+for GPUNUM in 1 4
+do
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+done
+done
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 7c080b7f3..2f012780d 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -1,27 +1,30 @@
import gzip
import random
+from functools import partial
+from time import time
import numpy as np
import torch
+import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
-NUM_BATCHES = int(1000)
+NUM_BATCHES = int(10)
+WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
@@ -63,9 +66,16 @@ def parse_args():
default=8,
help="batch size per DP group of training.",
)
+ parser.add_argument(
+ "--dummy_data",
+ type=bool,
+ default=False,
+ help="use dummy dataset.",
+ )
args = parser.parse_args()
return args
+
# helpers
def cycle(loader):
while True:
@@ -77,10 +87,22 @@ def decode_token(token):
return str(chr(max(32, token)))
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+ return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
+def get_model_size(model: nn.Module):
+ total_numel = 0
+ for module in model.modules():
+ for p in module.parameters(recurse=False):
+ total_numel += p.numel()
+ return total_numel
+
+
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__
@@ -104,6 +126,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
+
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -117,6 +140,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
+
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
@@ -143,20 +167,33 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
-
param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
- raise TypeError(f"{args.distplan} is error")
+ raise TypeError(f"{args.distplan} is not a valid distplan")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
+logger = get_dist_logger()
-with gzip.open("./data/enwik8.gz") as file:
- X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
- trX, vaX = np.split(X, [int(90e6)])
- data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+
+def generate_dataset(dummy_data: bool = False):
+ if not dummy_data:
+ with gzip.open("./data/enwik8.gz") as file:
+ X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+ trX, vaX = np.split(X, [int(90e6)])
+ data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+ # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
+ # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}")
+ return data_train, data_val
+ else:
+ return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))
+
+
+data_train, data_val = generate_dataset(args.dummy_data)
+
+print("generate dataset ready!")
class TextSamplerDataset(Dataset):
@@ -188,7 +225,7 @@ if args.distplan == "colossalai":
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
- model = PaLM(num_tokens=256, dim=512, depth=8)
+ model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
@@ -205,25 +242,42 @@ else:
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
-
+# model is sharded after TP
+numel = get_model_size(model)
+get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
# training
model.train()
-
+tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
if args.distplan == "colossalai":
optimizer.zero_grad()
-
+ start = time()
loss = model(next(train_loader))
+ fwd_end = time()
+ fwd_time = fwd_end - start
# loss.backward()
optimizer.backward(loss)
+ bwd_end = time()
+ bwd_time = bwd_end - fwd_end
- print(f"training loss: {loss.item()}")
+ # print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
+ optim_time = time() - bwd_end
+ step_time = time() - start
+
+ step_tflops = get_tflops_func(step_time)
+ logger.info(
+ f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+ ranks=[0],
+ )
+ if i >= WARMUP_BATCHES:
+ tflops_list.append(step_tflops)
+
else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
@@ -234,12 +288,16 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
optim.step()
optim.zero_grad()
- # TODO
- # if i % VALIDATE_EVERY == 0:
- # model.eval()
- # with torch.no_grad():
- # loss = model(next(val_loader))
- # print(f"validation loss: {loss.item()}")
+tflops_list.sort()
+median_index = (NUM_BATCHES - WARMUP_BATCHES) >> 1
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+
+# TODO
+# if i % VALIDATE_EVERY == 0:
+# model.eval()
+# with torch.no_grad():
+# loss = model(next(val_loader))
+# print(f"validation loss: {loss.item()}")
# if i % GENERATE_EVERY == 0:
# model.eval()
@@ -249,4 +307,4 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
# output_str = decode_tokens(sample[0])
- # print(output_str)
\ No newline at end of file
+ # print(output_str)
diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md
index bef7c8905..9c61e41cd 100644
--- a/examples/tutorial/README.md
+++ b/examples/tutorial/README.md
@@ -39,9 +39,6 @@ quickly deploy large AI model training and inference, reducing large AI model tr
- Try pre-trained OPT model weights with Colossal-AI
- Fine-tuning OPT with limited hardware using ZeRO, Gemini and parallelism
- Deploy the fine-tuned model to inference service
- - Acceleration of Stable Diffusion
- - Stable Diffusion with Lightning
- - Try Lightning Colossal-AI strategy to optimize memory and accelerate speed
## Discussion
@@ -168,26 +165,3 @@ docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-in
```bash
python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m
```
-
-## 🖼️ Accelerate Stable Diffusion with Colossal-AI
-1. Create a new environment for diffusion
-```bash
-conda env create -f environment.yaml
-conda activate ldm
-```
-2. Install Colossal-AI from our official page
-```bash
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-3. Install PyTorch Lightning compatible commit
-```bash
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-cd ..
-```
-
-4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`.
-5. Run training with CIFAR10.
-```bash
-python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml
-```
diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md
index e99a018c2..bb014b906 100644
--- a/examples/tutorial/auto_parallel/README.md
+++ b/examples/tutorial/auto_parallel/README.md
@@ -1,15 +1,45 @@
-# Auto-Parallelism with ResNet
+# Auto-Parallelism
+
+## Table of contents
+
+- [Auto-Parallelism](#auto-parallelism)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+ - [Setup](#setup)
+ - [Auto-Parallel Tutorial](#auto-parallel-tutorial)
+ - [Auto-Checkpoint Tutorial](#auto-checkpoint-tutorial)
+
+
+## 📚 Overview
+
+This tutorial folder contains a simple demo to run auto-parallelism with ResNet. Meanwhile, this directory also contains demo scripts to run automatic activation checkpointing, but both features are still experimental and there is no guarantee that they will work with your version of Colossal-AI.
+
+## 🚀 Quick Start
+
+### Setup
+
+1. Create a conda environment
-## 🚀Quick Start
-### Auto-Parallel Tutorial
-1. Install `pulp` and `coin-or-cbc` for the solver.
```bash
-pip install pulp
+conda create -n auto python=3.8
+conda activate auto
+```
+
+2. Install the dependencies in `requirements.txt` and `coin-or-cbc` for the solver.
+
+```bash
+pip install -r requirements.txt
conda install -c conda-forge coin-or-cbc
```
-2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset.
+
+
+### Auto-Parallel Tutorial
+
+Run the auto-parallel ResNet example with 4 GPUs on a synthetic dataset.
+
```bash
-colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
+colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
```
You should expect to see a log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), which is simply equivalent to data-parallel training.
@@ -17,57 +47,6 @@ You should expect to the log like this. This log shows the edge cost on the comp
### Auto-Checkpoint Tutorial
-1. Stay in the `auto_parallel` folder.
-2. Install the dependencies.
-```bash
-pip install matplotlib transformers
-```
-3. Run a simple resnet50 benchmark to automatically checkpoint the model.
-```bash
-python auto_ckpt_solver_test.py --model resnet50
-```
-
-You should expect the log to be like this
-
-
-This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger.
-```bash
-python auto_ckpt_solver_test.py --model gpt2
-```
-
-4. Run a simple benchmark to find the optimal batch size for checkpointed model.
-```bash
-python auto_ckpt_batchsize_test.py
-```
-
-You can expect the log to be like
-
-
-
-## Prepare Dataset
-
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
-
-```bash
-export DATA=/path/to/data
-```
-
-## extra requirements to use autoparallel
-
-```bash
-pip install pulp
-conda install coin-or-cbc
-```
-
-## Run on 2*2 device mesh
-
-```bash
-colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
-```
-
-## Auto Checkpoint Benchmarking
We prepare two benchmarks for you to test the performance of auto checkpoint.
@@ -86,21 +65,3 @@ python auto_ckpt_solver_test.py --model resnet50
# run auto_ckpt_batchsize_test.py
python auto_ckpt_batchsize_test.py
```
-
-There are some results for your reference
-
-## Auto Checkpoint Solver Test
-
-### ResNet 50
-
-
-### GPT2 Medium
-
-
-## Auto Checkpoint Batch Size Test
-```bash
-===============test summary================
-batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s
-batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s
-batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s
-```
diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
index e4aff13e4..15429f19c 100644
--- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -1,37 +1,12 @@
-import argparse
-import os
-from pathlib import Path
-
import torch
-from titans.utils import barrier_context
-from torch.fx import GraphModule
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm
import colossalai
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
-from colossalai.utils import get_dataloader
-
-DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
- return parser.parse_args()
def synthesize_data():
@@ -41,82 +16,15 @@ def synthesize_data():
def main():
- args = parse_args()
colossalai.launch_from_torch(config='./config.py')
logger = get_dist_logger()
- if not args.synthetic:
- with barrier_context():
- # build dataloaders
- train_dataset = CIFAR10(root=DATA_ROOT,
- download=True,
- transform=transforms.Compose([
- transforms.RandomCrop(size=32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]))
-
- test_dataset = CIFAR10(root=DATA_ROOT,
- train=False,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
- ]))
-
- train_dataloader = get_dataloader(
- dataset=train_dataset,
- add_sampler=True,
- shuffle=True,
- batch_size=gpc.config.BATCH_SIZE,
- pin_memory=True,
- )
-
- test_dataloader = get_dataloader(
- dataset=test_dataset,
- add_sampler=True,
- batch_size=gpc.config.BATCH_SIZE,
- pin_memory=True,
- )
- else:
- train_dataloader, test_dataloader = None, None
-
- # initialize device mesh
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
# trace the model with meta data
- tracer = ColoTracer()
model = resnet50(num_classes=10).cuda()
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
-
- # prepare info for solver
- solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
-
- # solve the solution
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- if gpc.get_global_rank() == 0:
- for index, node in enumerate(graph.nodes):
- print(node.name, node.strategies_vector[solution[index]].name)
-
- # process the graph for distributed training ability
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
- gm = runtime_apply_pass(gm)
- gm.recompile()
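+    # autoparallelize replaces the manual trace / solver / runtime-pass pipeline removed above:
+    # it traces the model with the meta input, searches for a sharding plan and applies it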
+ model = autoparallelize(model, input_sample)
# build criterion
criterion = torch.nn.CrossEntropyLoss()
@@ -127,65 +35,45 @@ def main():
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
for epoch in range(gpc.config.NUM_EPOCHS):
- gm.train()
+ model.train()
- if args.synthetic:
- # if we use synthetic data
- # we assume it only has 30 steps per epoch
- num_steps = range(30)
-
- else:
- # we use the actual number of steps for training
- num_steps = range(len(train_dataloader))
- data_iter = iter(train_dataloader)
+ # if we use synthetic data
+ # we assume it only has 10 steps per epoch
+ num_steps = range(10)
progress = tqdm(num_steps)
for _ in progress:
- if args.synthetic:
- # generate fake data
- img, label = synthesize_data()
- else:
- # get the real data
- img, label = next(data_iter)
+ # generate fake data
+ img, label = synthesize_data()
img = img.cuda()
label = label.cuda()
optimizer.zero_grad()
- output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = model(img)
train_loss = criterion(output, label)
train_loss.backward(train_loss)
optimizer.step()
lr_scheduler.step()
# run evaluation
- gm.eval()
+ model.eval()
correct = 0
total = 0
- if args.synthetic:
- # if we use synthetic data
- # we assume it only has 10 steps for evaluation
- num_steps = range(30)
-
- else:
- # we use the actual number of steps for training
- num_steps = range(len(test_dataloader))
- data_iter = iter(test_dataloader)
+ # if we use synthetic data
+ # we assume it only has 10 steps for evaluation
+ num_steps = range(10)
progress = tqdm(num_steps)
for _ in progress:
- if args.synthetic:
- # generate fake data
- img, label = synthesize_data()
- else:
- # get the real data
- img, label = next(data_iter)
+ # generate fake data
+ img, label = synthesize_data()
img = img.cuda()
label = label.cuda()
with torch.no_grad():
- output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = model(img)
test_loss = criterion(output, label)
pred = torch.argmax(output, dim=-1)
correct += torch.sum(pred == label)
diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py
index fa14eda74..52e0abcef 100644
--- a/examples/tutorial/auto_parallel/config.py
+++ b/examples/tutorial/auto_parallel/config.py
@@ -1,2 +1,2 @@
-BATCH_SIZE = 128
-NUM_EPOCHS = 10
+BATCH_SIZE = 32
+NUM_EPOCHS = 2
diff --git a/examples/tutorial/auto_parallel/requirements.txt b/examples/tutorial/auto_parallel/requirements.txt
index 137a69e80..ce89e7c80 100644
--- a/examples/tutorial/auto_parallel/requirements.txt
+++ b/examples/tutorial/auto_parallel/requirements.txt
@@ -1,2 +1,7 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+torch
+colossalai
+titans
+pulp
+datasets
+matplotlib
+transformers
diff --git a/examples/tutorial/stable_diffusion/setup.py b/examples/tutorial/auto_parallel/setup.py
similarity index 68%
rename from examples/tutorial/stable_diffusion/setup.py
rename to examples/tutorial/auto_parallel/setup.py
index a24d54167..6e6cff32e 100644
--- a/examples/tutorial/stable_diffusion/setup.py
+++ b/examples/tutorial/auto_parallel/setup.py
@@ -1,7 +1,7 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
setup(
- name='latent-diffusion',
+ name='auto_parallel',
version='0.0.1',
description='',
packages=find_packages(),
@@ -10,4 +10,4 @@ setup(
'numpy',
'tqdm',
],
-)
\ No newline at end of file
+)
diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh
new file mode 100644
index 000000000..bf6275b67
--- /dev/null
+++ b/examples/tutorial/auto_parallel/test_ci.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+conda install -c conda-forge coin-or-cbc
+colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
diff --git a/examples/tutorial/hybrid_parallel/README.md b/examples/tutorial/hybrid_parallel/README.md
index 6f975e863..1b5e54f92 100644
--- a/examples/tutorial/hybrid_parallel/README.md
+++ b/examples/tutorial/hybrid_parallel/README.md
@@ -1,45 +1,40 @@
# Multi-dimensional Parallelism with Colossal-AI
+## Table of contents
-## 🚀Quick Start
-1. Install our model zoo.
-```bash
-pip install titans
-```
-2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag.
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
+- [Overview](#-overview)
+- [Quick Start](#-quick-start)
+
+## 📚 Overview
+
+This example lets you quickly try out the hybrid parallelism provided by Colossal-AI.
+You can change the parameters below in `config.py` to try out different settings.
+
+```python
+# parallel setting
+TENSOR_PARALLEL_SIZE = 2
+TENSOR_PARALLEL_MODE = '1d'
+
+parallel = dict(
+ pipeline=2,
+ tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
+)
```
-3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
+## 🚀 Quick Start
+1. Install PyTorch
-## Install Titans Model Zoo
+2. Install the dependencies.
```bash
-pip install titans
+pip install -r requirements.txt
```
-
-## Prepare Dataset
-
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
+3. Run the training scripts with synthetic data.
```bash
-export DATA=/path/to/data
-```
-
-
-## Run on 2*2 device mesh
-
-Current configuration setting on `config.py` is TP=2, PP=2.
-
-```bash
-# train with cifar10
colossalai run --nproc_per_node 4 train.py --config config.py
-
-# train with synthetic data
-colossalai run --nproc_per_node 4 train.py --config config.py -s
```
+
+4. Modify the config file to play with different types of tensor parallelism. For example, change the tensor parallel size to 4 and the mode to 2d, then run on 8 GPUs, as sketched below.
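+
+A rough sketch of such a setting in `config.py` (the values here are only an illustration) could be:
+
+```python
+# hypothetical 8-GPU setting: 2 pipeline stages x a (2 x 2) tensor-parallel grid
+TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_MODE = '2d'
+
+parallel = dict(
+    pipeline=2,
+    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
+)
+```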
diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py
index 2450ab1c7..fe9abf2f1 100644
--- a/examples/tutorial/hybrid_parallel/config.py
+++ b/examples/tutorial/hybrid_parallel/config.py
@@ -3,20 +3,20 @@ from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 256
+BATCH_SIZE = 4
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 10
-WARMUP_EPOCHS = 3
+NUM_EPOCHS = 2
+WARMUP_EPOCHS = 1
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
-HIDDEN_SIZE = 512
+HIDDEN_SIZE = 128
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
-NUM_CLASSES = 1000
+NUM_CLASSES = 10
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
diff --git a/examples/tutorial/hybrid_parallel/requirements.txt b/examples/tutorial/hybrid_parallel/requirements.txt
index 137a69e80..99b7ecfe1 100644
--- a/examples/tutorial/hybrid_parallel/requirements.txt
+++ b/examples/tutorial/hybrid_parallel/requirements.txt
@@ -1,2 +1,3 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+torch
+colossalai
+titans
diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh
new file mode 100644
index 000000000..e0dbef354
--- /dev/null
+++ b/examples/tutorial/hybrid_parallel/test_ci.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+colossalai run --nproc_per_node 4 train.py --config config.py
diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py
index 0f2a207cb..4953d5350 100644
--- a/examples/tutorial/hybrid_parallel/train.py
+++ b/examples/tutorial/hybrid_parallel/train.py
@@ -1,7 +1,6 @@
import os
import torch
-from titans.dataloader.cifar10 import build_cifar
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm
@@ -12,7 +11,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
+from colossalai.utils import is_using_pp
class DummyDataloader():
@@ -42,12 +41,9 @@ class DummyDataloader():
def main():
- # initialize distributed setting
- parser = colossalai.get_default_parser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
- args = parser.parse_args()
-
# launch from torch
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)
# get logger
@@ -94,15 +90,10 @@ def main():
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
- # create dataloaders
- root = os.environ.get('DATA', '../data')
- if args.synthetic:
- # if we use synthetic dataset
- # we train for 30 steps and eval for 10 steps per epoch
- train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
- test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
- else:
- train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+ # use synthetic dataset
+ # we train for 10 steps and eval for 5 steps per epoch
+ train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+ test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
@@ -139,6 +130,7 @@ def main():
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
+ gpc.destroy()
if __name__ == '__main__':
diff --git a/examples/tutorial/large_batch_optimizer/README.md b/examples/tutorial/large_batch_optimizer/README.md
index 20bddb383..1a17c2d87 100644
--- a/examples/tutorial/large_batch_optimizer/README.md
+++ b/examples/tutorial/large_batch_optimizer/README.md
@@ -1,31 +1,37 @@
-# Comparison of Large Batch Training Optimization
+# Large Batch Training Optimization
-## 🚀Quick Start
-Run with synthetic data
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
+## Table of contents
+
+- [Large Batch Training Optimization](#large-batch-training-optimization)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+
+## 📚 Overview
+
+This example lets you quickly try out the large batch training optimization provided by Colossal-AI. We use a synthetic dataset throughout, so you don't need to prepare any data. You can try out the `Lamb` and `Lars` optimizers from Colossal-AI with the following code.
+
+```python
+from colossalai.nn.optimizer import Lamb, Lars
```
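+
+As a minimal sketch of how these are constructed (mirroring `train.py`; the model below is just a placeholder), you might write:
+
+```python
+import torch.nn as nn
+
+from colossalai.nn.optimizer import Lamb, Lars
+
+model = nn.Linear(32, 10)    # placeholder model
+
+# pick one of the two large-batch optimizers, then construct it like any torch optimizer
+optim_cls = Lars    # or Lamb
+optimizer = optim_cls(model.parameters(), lr=3e-3, weight_decay=0.3)
+```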
+## 🚀 Quick Start
-## Prepare Dataset
+1. Install PyTorch
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
+2. Install the dependencies.
```bash
-export DATA=/path/to/data
+pip install -r requirements.txt
```
-You can also use synthetic data for this tutorial if you don't wish to download the `CIFAR10` dataset by adding the `-s` or `--synthetic` flag to the command.
-
-
-## Run on 2*2 device mesh
+3. Run the training scripts with synthetic data.
```bash
-# run with cifar10
-colossalai run --nproc_per_node 4 train.py --config config.py
+# run on 4 GPUs
+# run with lars
+colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lars
-# run with synthetic dataset
-colossalai run --nproc_per_node 4 train.py --config config.py -s
+# run with lamb
+colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lamb
```
diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py
index e019154e4..2efa0ffd0 100644
--- a/examples/tutorial/large_batch_optimizer/config.py
+++ b/examples/tutorial/large_batch_optimizer/config.py
@@ -6,31 +6,11 @@ from colossalai.amp import AMP_TYPE
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 10
-WARMUP_EPOCHS = 3
+NUM_EPOCHS = 2
+WARMUP_EPOCHS = 1
# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 512
-DEPTH = 4
-NUM_HEADS = 4
-MLP_RATIO = 2
-NUM_CLASSES = 1000
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
-
-# parallel setting
-TENSOR_PARALLEL_SIZE = 2
-TENSOR_PARALLEL_MODE = '1d'
-
-parallel = dict(
- pipeline=2,
- tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
-)
+NUM_CLASSES = 10
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
-
-# pipeline config
-NUM_MICRO_BATCHES = parallel['pipeline']
diff --git a/examples/tutorial/large_batch_optimizer/requirements.txt b/examples/tutorial/large_batch_optimizer/requirements.txt
index 137a69e80..c01328775 100644
--- a/examples/tutorial/large_batch_optimizer/requirements.txt
+++ b/examples/tutorial/large_batch_optimizer/requirements.txt
@@ -1,2 +1,3 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+colossalai
+torch
+titans
diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh
new file mode 100644
index 000000000..89f426c54
--- /dev/null
+++ b/examples/tutorial/large_batch_optimizer/test_ci.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+
+# run test
+colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
+colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py
index d403c275d..35e54582f 100644
--- a/examples/tutorial/large_batch_optimizer/train.py
+++ b/examples/tutorial/large_batch_optimizer/train.py
@@ -1,19 +1,13 @@
-import os
-
import torch
-from titans.dataloader.cifar10 import build_cifar
-from titans.model.vit.vit import _create_vit_model
+import torch.nn as nn
+from torchvision.models import resnet18
from tqdm import tqdm
import colossalai
-from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
-from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lamb, Lars
-from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
class DummyDataloader():
@@ -45,7 +39,10 @@ class DummyDataloader():
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
+ parser.add_argument('--optimizer',
+ choices=['lars', 'lamb'],
+ help="Choose your large-batch optimizer",
+ required=True)
args = parser.parse_args()
# launch from torch
@@ -55,59 +52,22 @@ def main():
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
- if hasattr(gpc.config, 'LOG_PATH'):
- if gpc.get_global_rank() == 0:
- log_path = gpc.config.LOG_PATH
- if not os.path.exists(log_path):
- os.mkdir(log_path)
- logger.log_to_file(log_path)
+ # create synthetic dataloaders
+ train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+ test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
- use_pipeline = is_using_pp()
-
- # create model
- model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
- patch_size=gpc.config.PATCH_SIZE,
- hidden_size=gpc.config.HIDDEN_SIZE,
- depth=gpc.config.DEPTH,
- num_heads=gpc.config.NUM_HEADS,
- mlp_ratio=gpc.config.MLP_RATIO,
- num_classes=10,
- init_method='jax',
- checkpoint=gpc.config.CHECKPOINT)
-
- if use_pipeline:
- pipelinable = PipelinableContext()
- with pipelinable:
- model = _create_vit_model(**model_kwargs)
- pipelinable.to_layer_list()
- pipelinable.policy = "uniform"
- model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
- else:
- model = _create_vit_model(**model_kwargs)
-
- # count number of parameters
- total_numel = 0
- for p in model.parameters():
- total_numel += p.numel()
- if not gpc.is_initialized(ParallelMode.PIPELINE):
- pipeline_stage = 0
- else:
- pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
-
- # create dataloaders
- root = os.environ.get('DATA', '../data/')
- if args.synthetic:
- train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
- test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
- else:
- train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+ # build model
+ model = resnet18(num_classes=gpc.config.NUM_CLASSES)
# create loss function
- criterion = CrossEntropyLoss(label_smoothing=0.1)
+ criterion = nn.CrossEntropyLoss()
# create optimizer
- optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
+ if args.optimizer == "lars":
+ optim_cls = Lars
+ elif args.optimizer == "lamb":
+ optim_cls = Lamb
+ optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
diff --git a/examples/tutorial/sequence_parallel/README.md b/examples/tutorial/sequence_parallel/README.md
index 7058f53db..1b7c60e22 100644
--- a/examples/tutorial/sequence_parallel/README.md
+++ b/examples/tutorial/sequence_parallel/README.md
@@ -1,139 +1,56 @@
-# Sequence Parallelism with BERT
+# Sequence Parallelism
-In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
+## Table of contents
+
+- [Sequence Parallelism](#sequence-parallelism)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+ - [🏎 How to Train with Sequence Parallelism](#-how-to-train-with-sequence-parallelism)
+ - [Step 1. Configure your parameters](#step-1-configure-your-parameters)
+ - [Step 2. Invoke parallel training](#step-2-invoke-parallel-training)
+
+## 📚 Overview
+
+In this tutorial, we implement BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.
Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
-## 🚀Quick Start
-1. Run with the following command
+## 🚀 Quick Start
+
+1. Install PyTorch
+
+2. Install the dependencies.
+
+```bash
+pip install -r requirements.txt
+```
+
+3. Run with the following command
+
```bash
export PYTHONPATH=$PWD
-colossalai run --nproc_per_node 4 train.py -s
-```
-2. The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again.
-
-## How to Prepare WikiPedia Dataset
-
-First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items:
-1. raw WikiPedia dataset
-2. wikipedia extractor (extract data from the raw dataset)
-3. vocabulary file
-4. preprocessing scripts (generate final data from extracted data)
-
-For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file.
-
-```python
-# download raw data
-mkdir data && cd ./data
-wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
-
-# install wiki extractor
-git clone https://github.com/FrankLeeeee/wikiextractor.git
-pip install ./wikiextractor
-
-# extractmodule
-wikiextractor --json enwiki-latest-pages-articles.xml.bz2
-cat text/*/* > ./corpus.json
-cd ..
-
-# download vocab file
-mkdir vocab && cd ./vocab
-wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
-cd ..
-
-# preprocess some data
-git clone https://github.com/NVIDIA/Megatron-LM.git
-cd ./Megatron-LM
-python tools/preprocess_data.py \
- --input ../data/corpus.json \
- --output-prefix my-bert \
- --vocab ../vocab/bert-large-uncased-vocab.txt \
- --dataset-impl mmap \
- --tokenizer-type BertWordPieceLowerCase \
- --split-sentences \
- --workers 24
+# run with synthetic dataset
+colossalai run --nproc_per_node 4 train.py
```
-After running the preprocessing scripts, you will obtain two files:
-1. my-bert_text_sentence.bin
-2. my-bert_text_sentence.idx
+> The default config uses sequence parallel size = 2 and pipeline size = 1. Try changing the pipeline size to 2 and running it again.
-If you happen to encouter `index out of range` problem when running Megatron's script,
-this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below:
-```python
-class Encoder(object):
- def __init__(self, args):
- ...
-
- def initializer(self):
- ...
-
- def encode(self, json_line):
- data = json.loads(json_line)
- ids = {}
- for key in self.args.json_keys:
- text = data[key]
- doc_ids = []
-
- # lsg: avoid sentences which start with a punctuation
- # as it cannot be tokenized by splitter
- if len(text) > 0 and text[0] in string.punctuation:
- text = text[1:]
-
- for sentence in Encoder.splitter.tokenize(text):
- sentence_ids = Encoder.tokenizer.tokenize(sentence)
- if len(sentence_ids) > 0:
- doc_ids.append(sentence_ids)
- if len(doc_ids) > 0 and self.args.append_eod:
- doc_ids[-1].append(Encoder.tokenizer.eod)
- ids[key] = doc_ids
- return ids, len(json_line)
-```
-
-## How to Train with Sequence Parallelism
+## 🏎 How to Train with Sequence Parallelism
We provided `train.py` for you to execute training. Before invoking the script, there are several
steps to perform.
-### Step 1. Set data path and vocab path
-
-At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.
-
-```python
-DATA_PATH =
-VOCAB_FILE_PATH =
-```
-
-`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension.
-
-For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set
-
-```python
-DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
-```
-
-The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset
-(e.g. bert-large-uncased-vocab.txt).
-
-### Step 3. Make Dataset Helper
-
-Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`.
-
-```python
-cd ./data/datasets
-make
-```
-
-### Step 3. Configure your parameters
+### Step 1. Configure your parameters
In the `config.py` provided, a set of parameters is defined, including the training scheme, model config, etc.
You can also modify the Colossal-AI setting. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs, you can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=`.
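+
+As a hedged illustration (the exact keys should follow what is already in your `config.py`; `mode='sequence'` is an assumption here), an 8-GPU sequence-parallel setting might look like:
+
+```python
+# hypothetical parallel setting for 8 GPUs, all used for sequence parallelism
+parallel = dict(
+    pipeline=1,
+    tensor=dict(size=8, mode='sequence'),
+)
+```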
-### Step 4. Invoke parallel training
+### Step 2. Invoke parallel training
Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setting.
diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py
index df0c5282f..6edf9cc2c 100644
--- a/examples/tutorial/sequence_parallel/config.py
+++ b/examples/tutorial/sequence_parallel/config.py
@@ -1,11 +1,8 @@
from colossalai.amp import AMP_TYPE
-DATA_PATH = ''
-VOCAB_FILE_PATH = ''
-
# hyper-parameters
-TRAIN_ITERS = 1000000
-DECAY_ITERS = 990000
+TRAIN_ITERS = 10
+DECAY_ITERS = 4
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10
@@ -13,12 +10,12 @@ EVAL_INTERVAL = 10
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
-SEQ_LENGTH = 512
+SEQ_LENGTH = 128
# BERT config
-DEPTH = 12
-NUM_ATTENTION_HEADS = 12
-HIDDEN_SIZE = 768
+DEPTH = 4
+NUM_ATTENTION_HEADS = 4
+HIDDEN_SIZE = 128
# model config
ADD_BINARY_HEAD = False
diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt
index 137a69e80..b49a94554 100644
--- a/examples/tutorial/sequence_parallel/requirements.txt
+++ b/examples/tutorial/sequence_parallel/requirements.txt
@@ -1,2 +1,2 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+colossalai
+torch
diff --git a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh
new file mode 100644
index 000000000..7bc20de3b
--- /dev/null
+++ b/examples/tutorial/sequence_parallel/test_ci.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+
+# run test
+colossalai run --nproc_per_node 4 train.py
diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py
index b92061000..a89747b58 100644
--- a/examples/tutorial/sequence_parallel/train.py
+++ b/examples/tutorial/sequence_parallel/train.py
@@ -1,9 +1,8 @@
import argparse
import torch
-from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
-from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
+from data.dummy_dataloader import DummyDataloader
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert
@@ -36,7 +35,7 @@ def parse_args():
def pipeline_data_process_func(stage_output, micro_batch_data):
- tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
+ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE):
data = (tokens, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
@@ -53,36 +52,15 @@ def main():
logger = get_dist_logger()
- # build dataloader
- if not args.synthetic:
- initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
- VOCAB_SIZE = get_padded_vocab_size()
- trainloader, validloader, testloader = build_train_valid_test_data_iterators(
- train_iters=gpc.config.TRAIN_ITERS,
- global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
- eval_interval=gpc.config.EVAL_INTERVAL,
- eval_iters=gpc.config.EVAL_ITERS,
- data_prefix=[gpc.config.DATA_PATH],
- data_impl='mmap',
- splits_string='949,50,1',
- max_seq_length=gpc.config.SEQ_LENGTH,
- masked_lm_prob=0.15,
- short_seq_prob=0.1,
- seed=1234,
- skip_warmup=True,
- binary_head=False,
- )
- else:
- from data.dummy_dataloader import DummyDataloader
-
- BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
- VOCAB_SIZE = 30528
- trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
- vocab_size=VOCAB_SIZE,
- seq_length=gpc.config.SEQ_LENGTH)
- validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
- vocab_size=VOCAB_SIZE,
- seq_length=gpc.config.SEQ_LENGTH)
+ # build synthetic dataloader
+ BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
+ VOCAB_SIZE = 30528
+ trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+ vocab_size=VOCAB_SIZE,
+ seq_length=gpc.config.SEQ_LENGTH)
+ validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+ vocab_size=VOCAB_SIZE,
+ seq_length=gpc.config.SEQ_LENGTH)
logger.info("Dataloaders are built", ranks=[0])
diff --git a/examples/tutorial/stable_diffusion/LICENSE b/examples/tutorial/stable_diffusion/LICENSE
deleted file mode 100644
index 0e609df0d..000000000
--- a/examples/tutorial/stable_diffusion/LICENSE
+++ /dev/null
@@ -1,82 +0,0 @@
-Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
-
-CreativeML Open RAIL-M
-dated August 22, 2022
-
-Section I: PREAMBLE
-
-Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
-
-Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
-
-In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
-
-Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
-
-This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
-
-NOW THEREFORE, You and Licensor agree as follows:
-
-1. Definitions
-
-- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
-- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
-- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
-- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
-- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
-- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
-- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
-- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
-- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
-- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
-- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
-- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
-
-Section II: INTELLECTUAL PROPERTY RIGHTS
-
-Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
-
-2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
-3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
-
-Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
-
-4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
-Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
-You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
-You must cause any modified files to carry prominent notices stating that You changed the files;
-You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
-You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
-5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
-6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
-
-Section IV: OTHER PROVISIONS
-
-7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
-8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
-9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
-10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
-11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
-12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
-
-END OF TERMS AND CONDITIONS
-
-
-
-
-Attachment A
-
-Use Restrictions
-
-You agree not to use the Model or Derivatives of the Model:
-- In any way that violates any applicable national, federal, state, local or international law or regulation;
-- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
-- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
-- To generate or disseminate personal identifiable information that can be used to harm an individual;
-- To defame, disparage or otherwise harass others;
-- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
-- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
-- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
-- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
-- To provide medical advice and medical results interpretation;
-- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
diff --git a/examples/tutorial/stable_diffusion/README.md b/examples/tutorial/stable_diffusion/README.md
deleted file mode 100644
index a0ece4485..000000000
--- a/examples/tutorial/stable_diffusion/README.md
+++ /dev/null
@@ -1,149 +0,0 @@
-# Stable Diffusion with Colossal-AI
-*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and
-fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*
-
-We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies
-, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
-
-## 🚀Quick Start
-1. Create a new environment for diffusion
-```bash
-conda env create -f environment.yaml
-conda activate ldm
-```
-2. Install Colossal-AI from our official page
-```bash
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-3. Install PyTorch Lightning compatible commit
-```bash
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-cd ..
-```
-
-4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`.
-5. Run training with CIFAR10.
-```bash
-python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml
-```
-
-## Stable Diffusion
-[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
-model.
-Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
-Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
-this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
-
-
-
-
-
-[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB).
-
-
-
-
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
-
-```
-conda env create -f environment.yaml
-conda activate ldm
-```
-
-You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
-
-```
-conda install pytorch torchvision -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
-pip install -e .
-```
-
-### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website
-```
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-
-### Install [Lightning](https://github.com/Lightning-AI/lightning)
-We use the Sep. 2022 version with commit id as `b04a7aa`.
-```
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-```
-
-> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
-
-## Dataset
-The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
-you should the change the `data.file_path` in the `config/train_colossalai.yaml`
-
-## Training
-
-We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml`
-
-For example, you can run the training from colossalai by
-```
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
-```
-
-- you can change the `--logdir` the save the log information and the last checkpoint
-
-### Training config
-You can change the trainging config in the yaml file
-
-- accelerator: acceleratortype, default 'gpu'
-- devices: device number used for training, default 4
-- max_epochs: max training epochs
-- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
-
-## Example
-
-### Training on cifar10
-
-We provide the finetuning example on CIFAR10 dataset
-
-You can run by config `train_colossalai_cifar10.yaml`
-```
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
-```
-
-
-
-## Comments
-
-- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
-, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch),
-[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion).
-Thanks for open-sourcing!
-
-- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
-
-- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch).
-
-## BibTeX
-
-```
-@article{bian2021colossal,
- title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
- author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
- journal={arXiv preprint arXiv:2110.14883},
- year={2021}
-}
-@misc{rombach2021highresolution,
- title={High-Resolution Image Synthesis with Latent Diffusion Models},
- author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
- year={2021},
- eprint={2112.10752},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
-}
-@article{dao2022flashattention,
- title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
- author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
- journal={arXiv preprint arXiv:2205.14135},
- year={2022}
-}
-```
diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml
deleted file mode 100644
index c457787dd..000000000
--- a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml
+++ /dev/null
@@ -1,116 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 64
- wrap: False
- train:
- target: ldm.data.base.Txt2ImgIterableBaseDataset
- params:
- file_path: "/data/scratch/diffuser/laion_part0/"
- world_size: 1
- rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
- enable_distributed_storage: True,
- placement_policy: cuda
- force_outputs_fp32: False
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml
deleted file mode 100644
index 63b9d1c01..000000000
--- a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml
+++ /dev/null
@@ -1,123 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: txt
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 4
- num_workers: 4
- train:
- target: ldm.data.cifar10.hf_dataset
- params:
- name: cifar10
- image_transforms:
- - target: torchvision.transforms.Resize
- params:
- size: 512
- interpolation: 3
- - target: torchvision.transforms.RandomCrop
- params:
- size: 512
- - target: torchvision.transforms.RandomHorizontalFlip
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 2
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
-        enable_distributed_storage: True
- placement_policy: cuda
- force_outputs_fp32: False
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml
deleted file mode 100644
index 90d41258f..000000000
--- a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml
+++ /dev/null
@@ -1,113 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 32
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 100 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 64
- wrap: False
- train:
- target: ldm.data.base.Txt2ImgIterableBaseDataset
- params:
- file_path: "/data/scratch/diffuser/laion_part0/"
- world_size: 1
- rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.DDPStrategy
- params:
- find_unused_parameters: False
- log_every_n_steps: 2
-# max_steps: 60
- logger: True
- default_root_dir: "/tmp/diff_log/"
- # profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml
deleted file mode 100644
index 8b5d2adfa..000000000
--- a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 32
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
- check_nan_inf: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 10000 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 32
- wrap: False
- train:
- target: ldm.data.pokemon.PokemonDataset
- # params:
- # file_path: "/data/scratch/diffuser/laion_part0/"
- # world_size: 1
- # rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
-        enable_distributed_storage: True
- placement_policy: cuda
- force_outputs_fp32: False
- initial_scale: 65536
- min_scale: 1
- max_scale: 65536
- # max_scale: 4294967296
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/environment.yaml b/examples/tutorial/stable_diffusion/environment.yaml
deleted file mode 100644
index 7d8aec86f..000000000
--- a/examples/tutorial/stable_diffusion/environment.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-name: ldm
-channels:
- - pytorch
- - defaults
-dependencies:
- - python=3.9.12
- - pip=20.3
- - cudatoolkit=11.3
- - pytorch=1.11.0
- - torchvision=0.12.0
- - numpy=1.19.2
- - pip:
- - albumentations==0.4.3
- - datasets
- - diffusers
- - opencv-python==4.6.0.66
- - pudb==2019.2
- - invisible-watermark
- - imageio==2.9.0
- - imageio-ffmpeg==0.4.2
- - pytorch-lightning==1.8.0
- - omegaconf==2.1.1
- - test-tube>=0.7.5
- - streamlit>=0.73.1
- - einops==0.3.0
- - torch-fidelity==0.3.0
- - transformers==4.19.2
- - torchmetrics==0.7.0
- - kornia==0.6
- - prefetch_generator
- - colossalai
- - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- - -e git+https://github.com/openai/CLIP.git@main#egg=clip
- - -e .
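
The spec above is a standard conda environment file; it would typically be created with "conda env create -f environment.yaml" and activated with "conda activate ldm" (the name field above).
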
diff --git a/examples/tutorial/stable_diffusion/ldm/data/base.py b/examples/tutorial/stable_diffusion/ldm/data/base.py
deleted file mode 100644
index 4f3cd3571..000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/base.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import math
-from abc import abstractmethod
-
-import torch
-from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
-import os
-import numpy as np
-import cv2
-
-class Txt2ImgIterableBaseDataset(IterableDataset):
- '''
- Define an interface to make the IterableDatasets for text2img data chainable
- '''
- def __init__(self, file_path: str, rank, world_size):
- super().__init__()
- self.file_path = file_path
- self.folder_list = []
- self.file_list = []
- self.txt_list = []
- self.info = self._get_file_info(file_path)
- self.start = self.info['start']
- self.end = self.info['end']
- self.rank = rank
-
- self.world_size = world_size
- # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
- # self.iter_start = self.start + self.rank * self.per_worker
- # self.iter_end = min(self.iter_start + self.per_worker, self.end)
- # self.num_records = self.iter_end - self.iter_start
- # self.valid_ids = [i for i in range(self.iter_end)]
- self.num_records = self.end - self.start
- self.valid_ids = [i for i in range(self.end)]
-
- print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
-
- def __len__(self):
- # return self.iter_end - self.iter_start
- return self.end - self.start
-
- def __iter__(self):
- sample_iterator = self._sample_generator(self.start, self.end)
- # sample_iterator = self._sample_generator(self.iter_start, self.iter_end)
- return sample_iterator
-
- def _sample_generator(self, start, end):
- for idx in range(start, end):
- file_name = self.file_list[idx]
- txt_name = self.txt_list[idx]
- f_ = open(txt_name, 'r')
- txt_ = f_.read()
- f_.close()
- image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- image = torch.from_numpy(image) / 255
- yield {"caption": txt_, "image":image}
-
-
- def _get_file_info(self, file_path):
- info = \
- {
- "start": 1,
- "end": 0,
- }
- self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
- for folder in self.folder_list:
- files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
- txts = [k.replace('jpg', 'txt') for k in files]
- self.file_list.extend(files)
- self.txt_list.extend(txts)
- info['end'] = len(self.file_list)
- # with open(file_path, 'r') as fin:
- # for _ in enumerate(fin):
- # info['end'] += 1
- # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
- return info
\ No newline at end of file
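
A minimal usage sketch of the dataset above (not part of the patch); the folder path is a placeholder, and the assumed layout is subfolders of paired .jpg/.txt files, as _get_file_info expects.

    from torch.utils.data import DataLoader
    from ldm.data.base import Txt2ImgIterableBaseDataset

    ds = Txt2ImgIterableBaseDataset(file_path="/path/to/laion_part0/", rank=0, world_size=1)
    loader = DataLoader(ds, batch_size=4)   # IterableDataset: no sampler/shuffle
    batch = next(iter(loader))              # dict with "caption" and "image" entries
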
diff --git a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py
deleted file mode 100644
index 53cd61263..000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py
+++ /dev/null
@@ -1,184 +0,0 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
-from einops import rearrange
-from ldm.util import instantiate_from_config
-from datasets import load_dataset
-
-def make_multi_folder_data(paths, caption_files=None, **kwargs):
- """Make a concat dataset from multiple folders
-    Don't support captions yet
- If paths is a list, that's ok, if it's a Dict interpret it as:
- k=folder v=n_times to repeat that
- """
- list_of_paths = []
- if isinstance(paths, (Dict, DictConfig)):
- assert caption_files is None, \
- "Caption files not yet supported for repeats"
- for folder_path, repeats in paths.items():
- list_of_paths.extend([folder_path]*repeats)
- paths = list_of_paths
-
- if caption_files is not None:
- datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
- else:
- datasets = [FolderData(p, **kwargs) for p in paths]
- return torch.utils.data.ConcatDataset(datasets)
-
-class FolderData(Dataset):
- def __init__(self,
- root_dir,
- caption_file=None,
- image_transforms=[],
- ext="jpg",
- default_caption="",
- postprocess=None,
- return_paths=False,
- ) -> None:
- """Create a dataset from a folder of images.
- If you pass in a root directory it will be searched for images
- ending in ext (ext can be a list)
- """
- self.root_dir = Path(root_dir)
- self.default_caption = default_caption
- self.return_paths = return_paths
- if isinstance(postprocess, DictConfig):
- postprocess = instantiate_from_config(postprocess)
- self.postprocess = postprocess
- if caption_file is not None:
- with open(caption_file, "rt") as f:
- ext = Path(caption_file).suffix.lower()
- if ext == ".json":
- captions = json.load(f)
- elif ext == ".jsonl":
- lines = f.readlines()
- lines = [json.loads(x) for x in lines]
- captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
- else:
- raise ValueError(f"Unrecognised format: {ext}")
- self.captions = captions
- else:
- self.captions = None
-
- if not isinstance(ext, (tuple, list, ListConfig)):
- ext = [ext]
-
- # Only used if there is no caption file
- self.paths = []
- for e in ext:
- self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
- if isinstance(image_transforms, ListConfig):
- image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
- image_transforms.extend([transforms.ToTensor(),
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
- image_transforms = transforms.Compose(image_transforms)
- self.tform = image_transforms
-
-
- def __len__(self):
- if self.captions is not None:
- return len(self.captions.keys())
- else:
- return len(self.paths)
-
- def __getitem__(self, index):
- data = {}
- if self.captions is not None:
- chosen = list(self.captions.keys())[index]
- caption = self.captions.get(chosen, None)
- if caption is None:
- caption = self.default_caption
- filename = self.root_dir/chosen
- else:
- filename = self.paths[index]
-
- if self.return_paths:
- data["path"] = str(filename)
-
- im = Image.open(filename)
- im = self.process_im(im)
- data["image"] = im
-
- if self.captions is not None:
- data["txt"] = caption
- else:
- data["txt"] = self.default_caption
-
- if self.postprocess is not None:
- data = self.postprocess(data)
-
- return data
-
- def process_im(self, im):
- im = im.convert("RGB")
- return self.tform(im)
-
-def hf_dataset(
- name,
- image_transforms=[],
- image_column="img",
- label_column="label",
- text_column="txt",
- split='train',
- image_key='image',
- caption_key='txt',
- ):
- """Make huggingface dataset with appropriate list of transforms applied
- """
- ds = load_dataset(name, split=split)
- image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
- image_transforms.extend([transforms.ToTensor(),
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
- tform = transforms.Compose(image_transforms)
-
- assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
- assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}"
-
- def pre_process(examples):
- processed = {}
- processed[image_key] = [tform(im) for im in examples[image_column]]
-
- label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
-
- processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]
-
- return processed
-
- ds.set_transform(pre_process)
- return ds
-
-class TextOnly(Dataset):
- def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
- """Returns only captions with dummy images"""
- self.output_size = output_size
- self.image_key = image_key
- self.caption_key = caption_key
- if isinstance(captions, Path):
- self.captions = self._load_caption_file(captions)
- else:
- self.captions = captions
-
- if n_gpus > 1:
- # hack to make sure that all the captions appear on each gpu
- repeated = [n_gpus*[x] for x in self.captions]
- self.captions = []
- [self.captions.extend(x) for x in repeated]
-
- def __len__(self):
- return len(self.captions)
-
- def __getitem__(self, index):
- dummy_im = torch.zeros(3, self.output_size, self.output_size)
- dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
- return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
-
- def _load_caption_file(self, filename):
- with open(filename, 'rt') as f:
- captions = f.readlines()
- return [x.strip('\n') for x in captions]
\ No newline at end of file
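
A minimal sketch of calling hf_dataset directly, mirroring the cifar10 YAML config above; the extra torchvision transforms from that config are omitted here for brevity, so only the built-in ToTensor/rearrange transforms are applied.

    from ldm.data.cifar10 import hf_dataset

    ds = hf_dataset(name="cifar10", split="train")
    example = ds[0]    # {"image": (H, W, C) tensor in [-1, 1], "txt": e.g. "frog"}
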
diff --git a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py
deleted file mode 100644
index 1c473f9c6..000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py
+++ /dev/null
@@ -1,394 +0,0 @@
-import os, yaml, pickle, shutil, tarfile, glob
-import cv2
-import albumentations
-import PIL
-import numpy as np
-import torchvision.transforms.functional as TF
-from omegaconf import OmegaConf
-from functools import partial
-from PIL import Image
-from tqdm import tqdm
-from torch.utils.data import Dataset, Subset
-
-import taming.data.utils as tdu
-from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
-from taming.data.imagenet import ImagePaths
-
-from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
-
-
-def synset2idx(path_to_yaml="data/index_synset.yaml"):
- with open(path_to_yaml) as f:
- di2s = yaml.load(f)
- return dict((v,k) for k,v in di2s.items())
-
-
-class ImageNetBase(Dataset):
- def __init__(self, config=None):
- self.config = config or OmegaConf.create()
- if not type(self.config)==dict:
- self.config = OmegaConf.to_container(self.config)
- self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
- self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
- self._prepare()
- self._prepare_synset_to_human()
- self._prepare_idx_to_synset()
- self._prepare_human_to_integer_label()
- self._load()
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, i):
- return self.data[i]
-
- def _prepare(self):
- raise NotImplementedError()
-
- def _filter_relpaths(self, relpaths):
- ignore = set([
- "n06596364_9591.JPEG",
- ])
- relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
- if "sub_indices" in self.config:
- indices = str_to_indices(self.config["sub_indices"])
- synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
- self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
- files = []
- for rpath in relpaths:
- syn = rpath.split("/")[0]
- if syn in synsets:
- files.append(rpath)
- return files
- else:
- return relpaths
-
- def _prepare_synset_to_human(self):
- SIZE = 2655750
- URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
- self.human_dict = os.path.join(self.root, "synset_human.txt")
- if (not os.path.exists(self.human_dict) or
- not os.path.getsize(self.human_dict)==SIZE):
- download(URL, self.human_dict)
-
- def _prepare_idx_to_synset(self):
- URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
- self.idx2syn = os.path.join(self.root, "index_synset.yaml")
- if (not os.path.exists(self.idx2syn)):
- download(URL, self.idx2syn)
-
- def _prepare_human_to_integer_label(self):
- URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
- self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
- if (not os.path.exists(self.human2integer)):
- download(URL, self.human2integer)
- with open(self.human2integer, "r") as f:
- lines = f.read().splitlines()
- assert len(lines) == 1000
- self.human2integer_dict = dict()
- for line in lines:
- value, key = line.split(":")
- self.human2integer_dict[key] = int(value)
-
- def _load(self):
- with open(self.txt_filelist, "r") as f:
- self.relpaths = f.read().splitlines()
- l1 = len(self.relpaths)
- self.relpaths = self._filter_relpaths(self.relpaths)
- print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
-
- self.synsets = [p.split("/")[0] for p in self.relpaths]
- self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
-
- unique_synsets = np.unique(self.synsets)
- class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
- if not self.keep_orig_class_label:
- self.class_labels = [class_dict[s] for s in self.synsets]
- else:
- self.class_labels = [self.synset2idx[s] for s in self.synsets]
-
- with open(self.human_dict, "r") as f:
- human_dict = f.read().splitlines()
- human_dict = dict(line.split(maxsplit=1) for line in human_dict)
-
- self.human_labels = [human_dict[s] for s in self.synsets]
-
- labels = {
- "relpath": np.array(self.relpaths),
- "synsets": np.array(self.synsets),
- "class_label": np.array(self.class_labels),
- "human_label": np.array(self.human_labels),
- }
-
- if self.process_images:
- self.size = retrieve(self.config, "size", default=256)
- self.data = ImagePaths(self.abspaths,
- labels=labels,
- size=self.size,
- random_crop=self.random_crop,
- )
- else:
- self.data = self.abspaths
-
-
-class ImageNetTrain(ImageNetBase):
- NAME = "ILSVRC2012_train"
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
- AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
- FILES = [
- "ILSVRC2012_img_train.tar",
- ]
- SIZES = [
- 147897477120,
- ]
-
- def __init__(self, process_images=True, data_root=None, **kwargs):
- self.process_images = process_images
- self.data_root = data_root
- super().__init__(**kwargs)
-
- def _prepare(self):
- if self.data_root:
- self.root = os.path.join(self.data_root, self.NAME)
- else:
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
-
- self.datadir = os.path.join(self.root, "data")
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
- self.expected_length = 1281167
- self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
- default=True)
- if not tdu.is_prepared(self.root):
- # prep
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
-
- datadir = self.datadir
- if not os.path.exists(datadir):
- path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
- import academictorrents as at
- atpath = at.get(self.AT_HASH, datastore=self.root)
- assert atpath == path
-
- print("Extracting {} to {}".format(path, datadir))
- os.makedirs(datadir, exist_ok=True)
- with tarfile.open(path, "r:") as tar:
- tar.extractall(path=datadir)
-
- print("Extracting sub-tars.")
- subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
- for subpath in tqdm(subpaths):
- subdir = subpath[:-len(".tar")]
- os.makedirs(subdir, exist_ok=True)
- with tarfile.open(subpath, "r:") as tar:
- tar.extractall(path=subdir)
-
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
- filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
- with open(self.txt_filelist, "w") as f:
- f.write(filelist)
-
- tdu.mark_prepared(self.root)
-
-
-class ImageNetValidation(ImageNetBase):
- NAME = "ILSVRC2012_validation"
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
- AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
- VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
- FILES = [
- "ILSVRC2012_img_val.tar",
- "validation_synset.txt",
- ]
- SIZES = [
- 6744924160,
- 1950000,
- ]
-
- def __init__(self, process_images=True, data_root=None, **kwargs):
- self.data_root = data_root
- self.process_images = process_images
- super().__init__(**kwargs)
-
- def _prepare(self):
- if self.data_root:
- self.root = os.path.join(self.data_root, self.NAME)
- else:
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
- self.datadir = os.path.join(self.root, "data")
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
- self.expected_length = 50000
- self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
- default=False)
- if not tdu.is_prepared(self.root):
- # prep
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
-
- datadir = self.datadir
- if not os.path.exists(datadir):
- path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
- import academictorrents as at
- atpath = at.get(self.AT_HASH, datastore=self.root)
- assert atpath == path
-
- print("Extracting {} to {}".format(path, datadir))
- os.makedirs(datadir, exist_ok=True)
- with tarfile.open(path, "r:") as tar:
- tar.extractall(path=datadir)
-
- vspath = os.path.join(self.root, self.FILES[1])
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
- download(self.VS_URL, vspath)
-
- with open(vspath, "r") as f:
- synset_dict = f.read().splitlines()
- synset_dict = dict(line.split() for line in synset_dict)
-
- print("Reorganizing into synset folders")
- synsets = np.unique(list(synset_dict.values()))
- for s in synsets:
- os.makedirs(os.path.join(datadir, s), exist_ok=True)
- for k, v in synset_dict.items():
- src = os.path.join(datadir, k)
- dst = os.path.join(datadir, v)
- shutil.move(src, dst)
-
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
- filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
- with open(self.txt_filelist, "w") as f:
- f.write(filelist)
-
- tdu.mark_prepared(self.root)
-
-
-
-class ImageNetSR(Dataset):
- def __init__(self, size=None,
- degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
- random_crop=True):
- """
- Imagenet Superresolution Dataloader
-        Performs the following ops in order:
- 1. crops a crop of size s from image either as random or center crop
- 2. resizes crop to size with cv2.area_interpolation
- 3. degrades resized crop with degradation_fn
-
- :param size: resizing to size after cropping
- :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
- :param downscale_f: Low Resolution Downsample factor
- :param min_crop_f: determines crop size s,
- where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
- :param max_crop_f: ""
- :param data_root:
- :param random_crop:
- """
- self.base = self.get_base()
- assert size
- assert (size / downscale_f).is_integer()
- self.size = size
- self.LR_size = int(size / downscale_f)
- self.min_crop_f = min_crop_f
- self.max_crop_f = max_crop_f
- assert(max_crop_f <= 1.)
- self.center_crop = not random_crop
-
- self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
-
-        self.pil_interpolation = False   # gets reset later in case interp_op is from pillow
-
- if degradation == "bsrgan":
- self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
-
- elif degradation == "bsrgan_light":
- self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
-
- else:
- interpolation_fn = {
- "cv_nearest": cv2.INTER_NEAREST,
- "cv_bilinear": cv2.INTER_LINEAR,
- "cv_bicubic": cv2.INTER_CUBIC,
- "cv_area": cv2.INTER_AREA,
- "cv_lanczos": cv2.INTER_LANCZOS4,
- "pil_nearest": PIL.Image.NEAREST,
- "pil_bilinear": PIL.Image.BILINEAR,
- "pil_bicubic": PIL.Image.BICUBIC,
- "pil_box": PIL.Image.BOX,
- "pil_hamming": PIL.Image.HAMMING,
- "pil_lanczos": PIL.Image.LANCZOS,
- }[degradation]
-
- self.pil_interpolation = degradation.startswith("pil_")
-
- if self.pil_interpolation:
- self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
-
- else:
- self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
- interpolation=interpolation_fn)
-
- def __len__(self):
- return len(self.base)
-
- def __getitem__(self, i):
- example = self.base[i]
- image = Image.open(example["file_path_"])
-
- if not image.mode == "RGB":
- image = image.convert("RGB")
-
- image = np.array(image).astype(np.uint8)
-
- min_side_len = min(image.shape[:2])
- crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
- crop_side_len = int(crop_side_len)
-
- if self.center_crop:
- self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
-
- else:
- self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
-
- image = self.cropper(image=image)["image"]
- image = self.image_rescaler(image=image)["image"]
-
- if self.pil_interpolation:
- image_pil = PIL.Image.fromarray(image)
- LR_image = self.degradation_process(image_pil)
- LR_image = np.array(LR_image).astype(np.uint8)
-
- else:
- LR_image = self.degradation_process(image=image)["image"]
-
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
-
- return example
-
-
-class ImageNetSRTrain(ImageNetSR):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def get_base(self):
- with open("data/imagenet_train_hr_indices.p", "rb") as f:
- indices = pickle.load(f)
- dset = ImageNetTrain(process_images=False,)
- return Subset(dset, indices)
-
-
-class ImageNetSRValidation(ImageNetSR):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def get_base(self):
- with open("data/imagenet_val_hr_indices.p", "rb") as f:
- indices = pickle.load(f)
- dset = ImageNetValidation(process_images=False,)
- return Subset(dset, indices)
diff --git a/examples/tutorial/stable_diffusion/ldm/data/lsun.py b/examples/tutorial/stable_diffusion/ldm/data/lsun.py
deleted file mode 100644
index 6256e4571..000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/lsun.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import numpy as np
-import PIL
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-
-class LSUNBase(Dataset):
- def __init__(self,
- txt_file,
- data_root,
- size=None,
- interpolation="bicubic",
- flip_p=0.5
- ):
- self.data_paths = txt_file
- self.data_root = data_root
- with open(self.data_paths, "r") as f:
- self.image_paths = f.read().splitlines()
- self._length = len(self.image_paths)
- self.labels = {
- "relative_file_path_": [l for l in self.image_paths],
- "file_path_": [os.path.join(self.data_root, l)
- for l in self.image_paths],
- }
-
- self.size = size
- self.interpolation = {"linear": PIL.Image.LINEAR,
- "bilinear": PIL.Image.BILINEAR,
- "bicubic": PIL.Image.BICUBIC,
- "lanczos": PIL.Image.LANCZOS,
- }[interpolation]
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, i):
- example = dict((k, self.labels[k][i]) for k in self.labels)
- image = Image.open(example["file_path_"])
- if not image.mode == "RGB":
- image = image.convert("RGB")
-
- # default to score-sde preprocessing
- img = np.array(image).astype(np.uint8)
- crop = min(img.shape[0], img.shape[1])
- h, w, = img.shape[0], img.shape[1]
- img = img[(h - crop) // 2:(h + crop) // 2,
- (w - crop) // 2:(w + crop) // 2]
-
- image = Image.fromarray(img)
- if self.size is not None:
- image = image.resize((self.size, self.size), resample=self.interpolation)
-
- image = self.flip(image)
- image = np.array(image).astype(np.uint8)
- example["image"] = (image / 127.5 - 1.0).astype(np.float32)
- return example
-
-
-class LSUNChurchesTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
-
-
-class LSUNChurchesValidation(LSUNBase):
- def __init__(self, flip_p=0., **kwargs):
- super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
- flip_p=flip_p, **kwargs)
-
-
-class LSUNBedroomsTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
-
-
-class LSUNBedroomsValidation(LSUNBase):
- def __init__(self, flip_p=0.0, **kwargs):
- super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
- flip_p=flip_p, **kwargs)
-
-
-class LSUNCatsTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
-
-
-class LSUNCatsValidation(LSUNBase):
- def __init__(self, flip_p=0., **kwargs):
- super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
- flip_p=flip_p, **kwargs)
diff --git a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py
deleted file mode 100644
index be39da9ca..000000000
--- a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import numpy as np
-
-
-class LambdaWarmUpCosineScheduler:
- """
- note: use with a base_lr of 1.0
- """
- def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
- self.lr_warm_up_steps = warm_up_steps
- self.lr_start = lr_start
- self.lr_min = lr_min
- self.lr_max = lr_max
- self.lr_max_decay_steps = max_decay_steps
- self.last_lr = 0.
- self.verbosity_interval = verbosity_interval
-
- def schedule(self, n, **kwargs):
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
- if n < self.lr_warm_up_steps:
- lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
- self.last_lr = lr
- return lr
- else:
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
- t = min(t, 1.0)
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
- 1 + np.cos(t * np.pi))
- self.last_lr = lr
- return lr
-
- def __call__(self, n, **kwargs):
- return self.schedule(n,**kwargs)
-
-
-class LambdaWarmUpCosineScheduler2:
- """
- supports repeated iterations, configurable via lists
- note: use with a base_lr of 1.0.
- """
- def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
- assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
- self.lr_warm_up_steps = warm_up_steps
- self.f_start = f_start
- self.f_min = f_min
- self.f_max = f_max
- self.cycle_lengths = cycle_lengths
- self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
- self.last_f = 0.
- self.verbosity_interval = verbosity_interval
-
- def find_in_interval(self, n):
- interval = 0
- for cl in self.cum_cycles[1:]:
- if n <= cl:
- return interval
- interval += 1
-
- def schedule(self, n, **kwargs):
- cycle = self.find_in_interval(n)
- n = n - self.cum_cycles[cycle]
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
- f"current cycle {cycle}")
- if n < self.lr_warm_up_steps[cycle]:
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
- self.last_f = f
- return f
- else:
- t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
- t = min(t, 1.0)
- f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
- 1 + np.cos(t * np.pi))
- self.last_f = f
- return f
-
- def __call__(self, n, **kwargs):
- return self.schedule(n, **kwargs)
-
-
-class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
-
- def schedule(self, n, **kwargs):
- cycle = self.find_in_interval(n)
- n = n - self.cum_cycles[cycle]
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
- f"current cycle {cycle}")
-
- if n < self.lr_warm_up_steps[cycle]:
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
- self.last_f = f
- return f
- else:
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
- self.last_f = f
- return f
-
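
A sketch of how LambdaLinearScheduler plugs into torch's LambdaLR, using the values from the scheduler_config blocks in the configs above; the dummy module exists only to give the optimizer some parameters.

    import torch
    from torch.optim.lr_scheduler import LambdaLR
    from ldm.lr_scheduler import LambdaLinearScheduler

    dummy = torch.nn.Linear(4, 4)                            # stand-in module for illustration
    opt = torch.optim.AdamW(dummy.parameters(), lr=1.0e-4)   # base_learning_rate from the configs
    sched = LambdaLinearScheduler(warm_up_steps=[10000], f_min=[1.e-10], f_max=[1.e-4],
                                  f_start=[1.e-6], cycle_lengths=[10000000000000])
    lr_scheduler = LambdaLR(opt, lr_lambda=sched.schedule)   # effective lr = base lr * schedule(step)
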
diff --git a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py
deleted file mode 100644
index 873d8b69b..000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py
+++ /dev/null
@@ -1,544 +0,0 @@
-import torch
-import numpy as np
-import pytorch_lightning as pl
-import torch.nn.functional as F
-from contextlib import contextmanager
-from packaging import version
-from torch.optim.lr_scheduler import LambdaLR
-
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-
-from ldm.modules.diffusionmodules.model import Encoder, Decoder
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from ldm.modules.ema import LitEma    # used by the use_ema path below; assumed module location
-
-from ldm.util import instantiate_from_config
-
-
-class VQModel(pl.LightningModule):
- def __init__(self,
- ddconfig,
- lossconfig,
- n_embed,
- embed_dim,
- ckpt_path=None,
- ignore_keys=[],
- image_key="image",
- colorize_nlabels=None,
- monitor=None,
- batch_resize_range=None,
- scheduler_config=None,
- lr_g_factor=1.0,
- remap=None,
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
- use_ema=False
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.n_embed = n_embed
- self.image_key = image_key
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
- self.loss = instantiate_from_config(lossconfig)
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
- remap=remap,
- sane_index_shape=sane_index_shape)
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
- if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
- if monitor is not None:
- self.monitor = monitor
- self.batch_resize_range = batch_resize_range
- if self.batch_resize_range is not None:
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
-
- self.use_ema = use_ema
- if self.use_ema:
- self.model_ema = LitEma(self)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
- self.scheduler_config = scheduler_config
- self.lr_g_factor = lr_g_factor
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.parameters())
- self.model_ema.copy_to(self)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- print(f"Unexpected Keys: {unexpected}")
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self)
-
- def encode(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- quant, emb_loss, info = self.quantize(h)
- return quant, emb_loss, info
-
- def encode_to_prequant(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- return h
-
- def decode(self, quant):
- quant = self.post_quant_conv(quant)
- dec = self.decoder(quant)
- return dec
-
- def decode_code(self, code_b):
- quant_b = self.quantize.embed_code(code_b)
- dec = self.decode(quant_b)
- return dec
-
- def forward(self, input, return_pred_indices=False):
- quant, diff, (_,_,ind) = self.encode(input)
- dec = self.decode(quant)
- if return_pred_indices:
- return dec, diff, ind
- return dec, diff
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
- if self.batch_resize_range is not None:
- lower_size = self.batch_resize_range[0]
- upper_size = self.batch_resize_range[1]
- if self.global_step <= 4:
- # do the first few batches with max size to avoid later oom
- new_resize = upper_size
- else:
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
- if new_resize != x.shape[2]:
- x = F.interpolate(x, size=new_resize, mode="bicubic")
- x = x.detach()
- return x
-
- def training_step(self, batch, batch_idx, optimizer_idx):
- # https://github.com/pytorch/pytorch/issues/37142
- # try not to fool the heuristics
- x = self.get_input(batch, self.image_key)
- xrec, qloss, ind = self(x, return_pred_indices=True)
-
- if optimizer_idx == 0:
- # autoencode
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train",
- predicted_indices=ind)
-
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
- return aeloss
-
- if optimizer_idx == 1:
- # discriminator
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
- return discloss
-
- def validation_step(self, batch, batch_idx):
- log_dict = self._validation_step(batch, batch_idx)
- with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
- return log_dict
-
- def _validation_step(self, batch, batch_idx, suffix=""):
- x = self.get_input(batch, self.image_key)
- xrec, qloss, ind = self(x, return_pred_indices=True)
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val"+suffix,
- predicted_indices=ind
- )
-
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val"+suffix,
- predicted_indices=ind
- )
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
- self.log(f"val{suffix}/rec_loss", rec_loss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
- self.log(f"val{suffix}/aeloss", aeloss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
- del log_dict_ae[f"val{suffix}/rec_loss"]
- self.log_dict(log_dict_ae)
- self.log_dict(log_dict_disc)
- return self.log_dict
-
- def configure_optimizers(self):
- lr_d = self.learning_rate
- lr_g = self.lr_g_factor*self.learning_rate
- print("lr_d", lr_d)
- print("lr_g", lr_g)
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quantize.parameters())+
- list(self.quant_conv.parameters())+
- list(self.post_quant_conv.parameters()),
- lr=lr_g, betas=(0.5, 0.9))
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
- lr=lr_d, betas=(0.5, 0.9))
-
- if self.scheduler_config is not None:
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- },
- {
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- },
- ]
- return [opt_ae, opt_disc], scheduler
- return [opt_ae, opt_disc], []
-
- def get_last_layer(self):
- return self.decoder.conv_out.weight
-
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
- log = dict()
- x = self.get_input(batch, self.image_key)
- x = x.to(self.device)
- if only_inputs:
- log["inputs"] = x
- return log
- xrec, _ = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec.shape[1] > 3
- x = self.to_rgb(x)
- xrec = self.to_rgb(xrec)
- log["inputs"] = x
- log["reconstructions"] = xrec
- if plot_ema:
- with self.ema_scope():
- xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
- log["reconstructions_ema"] = xrec_ema
- return log
-
- def to_rgb(self, x):
- assert self.image_key == "segmentation"
- if not hasattr(self, "colorize"):
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
- x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
- return x
-
-
-class VQModelInterface(VQModel):
- def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
- self.embed_dim = embed_dim
-
- def encode(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- return h
-
- def decode(self, h, force_not_quantize=False):
- # also go through quantization layer
- if not force_not_quantize:
- quant, emb_loss, info = self.quantize(h)
- else:
- quant = h
- quant = self.post_quant_conv(quant)
- dec = self.decoder(quant)
- return dec
-
-
-class AutoencoderKL(pl.LightningModule):
- def __init__(self,
- ddconfig,
- lossconfig,
- embed_dim,
- ckpt_path=None,
- ignore_keys=[],
- image_key="image",
- colorize_nlabels=None,
- monitor=None,
- from_pretrained: str=None
- ):
- super().__init__()
- self.image_key = image_key
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
- self.loss = instantiate_from_config(lossconfig)
- assert ddconfig["double_z"]
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
- self.embed_dim = embed_dim
- if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
- if monitor is not None:
- self.monitor = monitor
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
- from diffusers.modeling_utils import load_state_dict
- if from_pretrained is not None:
- state_dict = load_state_dict(from_pretrained)
- self._load_pretrained_model(state_dict)
-
- def _state_key_mapping(self, state_dict: dict):
- import re
- res_dict = {}
- key_list = state_dict.keys()
- key_str = " ".join(key_list)
- up_block_pattern = re.compile('upsamplers')
- p1 = re.compile('mid.block_[0-9]')
- p2 = re.compile('decoder.up.[0-9]')
- up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
- for key_, val_ in state_dict.items():
- key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
- .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
- .replace('mid.attentions.0.key', 'mid.attn_1.k')\
- .replace('mid.attentions.0.query', 'mid.attn_1.q') \
- .replace('mid.attentions.0.value', 'mid.attn_1.v') \
- .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
- .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
- .replace('upsamplers.0', 'upsample')\
- .replace('downsamplers.0', 'downsample')\
- .replace('conv_shortcut', 'nin_shortcut')\
- .replace('conv_norm_out', 'norm_out')
-
- mid_list = re.findall(p1, key_)
- if len(mid_list) != 0:
- mid_str = mid_list[0]
- mid_id = int(mid_str[-1]) + 1
- key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
-
- up_list = re.findall(p2, key_)
- if len(up_list) != 0:
- up_str = up_list[0]
- up_id = up_blocks_count - 1 -int(up_str[-1])
- key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
- res_dict[key_] = val_
- return res_dict
-
- def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
- state_dict = self._state_key_mapping(state_dict)
- model_state_dict = self.state_dict()
- loaded_keys = [k for k in state_dict.keys()]
- expected_keys = list(model_state_dict.keys())
- original_loaded_keys = loaded_keys
- missing_keys = list(set(expected_keys) - set(loaded_keys))
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
- if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
- state_dict,
- model_state_dict,
- original_loaded_keys,
- ignore_mismatched_sizes,
- )
- error_msgs = self._load_state_dict_into_model(state_dict)
- return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
- def _load_state_dict_into_model(self, state_dict):
- # Convert old format to new format if needed from a PyTorch state_dict
- # copy state_dict so _load_from_state_dict can modify it
- state_dict = state_dict.copy()
- error_msgs = []
-
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
- # so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix=""):
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
- module._load_from_state_dict(*args)
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(self)
-
- return error_msgs
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- self.load_state_dict(sd, strict=False)
- print(f"Restored from {path}")
-
- def encode(self, x):
- h = self.encoder(x)
- moments = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(moments)
- return posterior
-
- def decode(self, z):
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
- return dec
-
- def forward(self, input, sample_posterior=True):
- posterior = self.encode(input)
- if sample_posterior:
- z = posterior.sample()
- else:
- z = posterior.mode()
- dec = self.decode(z)
- return dec, posterior
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
- return x
-
- def training_step(self, batch, batch_idx, optimizer_idx):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
-
- if optimizer_idx == 0:
- # train encoder+decoder+logvar
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
- return aeloss
-
- if optimizer_idx == 1:
- # train the discriminator
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
-
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
- return discloss
-
- def validation_step(self, batch, batch_idx):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
- last_layer=self.get_last_layer(), split="val")
-
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
- last_layer=self.get_last_layer(), split="val")
-
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
- self.log_dict(log_dict_ae)
- self.log_dict(log_dict_disc)
- return self.log_dict
-
- def configure_optimizers(self):
- lr = self.learning_rate
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quant_conv.parameters())+
- list(self.post_quant_conv.parameters()),
- lr=lr, betas=(0.5, 0.9))
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
- lr=lr, betas=(0.5, 0.9))
- return [opt_ae, opt_disc], []
-
- def get_last_layer(self):
- return self.decoder.conv_out.weight
-
- @torch.no_grad()
- def log_images(self, batch, only_inputs=False, **kwargs):
- log = dict()
- x = self.get_input(batch, self.image_key)
- x = x.to(self.device)
- if not only_inputs:
- xrec, posterior = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec.shape[1] > 3
- x = self.to_rgb(x)
- xrec = self.to_rgb(xrec)
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
- log["reconstructions"] = xrec
- log["inputs"] = x
- return log
-
- def to_rgb(self, x):
- assert self.image_key == "segmentation"
- if not hasattr(self, "colorize"):
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
- x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
- return x
-
-
-class IdentityFirstStage(torch.nn.Module):
- def __init__(self, *args, vq_interface=False, **kwargs):
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
- super().__init__()
-
- def encode(self, x, *args, **kwargs):
- return x
-
- def decode(self, x, *args, **kwargs):
- return x
-
- def quantize(self, x, *args, **kwargs):
- if self.vq_interface:
- return x, None, [None, None, None]
- return x
-
- def forward(self, x, *args, **kwargs):
- return x
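
A minimal round-trip sketch for AutoencoderKL as configured in the first_stage_config blocks above (random weights, no pretrained checkpoint; the ddconfig is copied from those YAML files).

    import torch
    from ldm.models.autoencoder import AutoencoderKL

    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[],
                    dropout=0.0)
    vae = AutoencoderKL(ddconfig=ddconfig, lossconfig={"target": "torch.nn.Identity"}, embed_dim=4)
    x = torch.randn(1, 3, 256, 256)
    posterior = vae.encode(x)                # DiagonalGaussianDistribution over the latent
    recon = vae.decode(posterior.sample())   # reconstruction of shape (1, 3, 256, 256)
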
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py
deleted file mode 100644
index 67e98b9d8..000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import os
-import torch
-import pytorch_lightning as pl
-from omegaconf import OmegaConf
-from torch.nn import functional as F
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import LambdaLR
-from copy import deepcopy
-from einops import rearrange
-from glob import glob
-from natsort import natsorted
-
-from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
-from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
-
-__models__ = {
- 'class_label': EncoderUNetModel,
- 'segmentation': UNetModel
-}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-class NoisyLatentImageClassifier(pl.LightningModule):
-
- def __init__(self,
- diffusion_path,
- num_classes,
- ckpt_path=None,
- pool='attention',
- label_key=None,
- diffusion_ckpt_path=None,
- scheduler_config=None,
- weight_decay=1.e-2,
- log_steps=10,
- monitor='val/loss',
- *args,
- **kwargs):
- super().__init__(*args, **kwargs)
- self.num_classes = num_classes
- # get latest config of diffusion model
- diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
- self.diffusion_config = OmegaConf.load(diffusion_config).model
- self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
- self.load_diffusion()
-
- self.monitor = monitor
- self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
- self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
- self.log_steps = log_steps
-
- self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
- else self.diffusion_model.cond_stage_key
-
- assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
-
- if self.label_key not in __models__:
- raise NotImplementedError()
-
- self.load_classifier(ckpt_path, pool)
-
- self.scheduler_config = scheduler_config
- self.use_scheduler = self.scheduler_config is not None
- self.weight_decay = weight_decay
-
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
- def load_diffusion(self):
- model = instantiate_from_config(self.diffusion_config)
- self.diffusion_model = model.eval()
- self.diffusion_model.train = disabled_train
- for param in self.diffusion_model.parameters():
- param.requires_grad = False
-
- def load_classifier(self, ckpt_path, pool):
- model_config = deepcopy(self.diffusion_config.params.unet_config.params)
- model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
- model_config.out_channels = self.num_classes
- if self.label_key == 'class_label':
- model_config.pool = pool
-
- self.model = __models__[self.label_key](**model_config)
- if ckpt_path is not None:
- print('#####################################################################')
- print(f'load from ckpt "{ckpt_path}"')
- print('#####################################################################')
- self.init_from_ckpt(ckpt_path)
-
- @torch.no_grad()
- def get_x_noisy(self, x, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x))
- continuous_sqrt_alpha_cumprod = None
- if self.diffusion_model.use_continuous_noise:
- continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
- # todo: make sure t+1 is correct here
-
- return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
- continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
-
- def forward(self, x_noisy, t, *args, **kwargs):
- return self.model(x_noisy, t)
-
- @torch.no_grad()
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = rearrange(x, 'b h w c -> b c h w')
- x = x.to(memory_format=torch.contiguous_format).float()
- return x
-
- @torch.no_grad()
- def get_conditioning(self, batch, k=None):
- if k is None:
- k = self.label_key
- assert k is not None, 'Needs to provide label key'
-
- targets = batch[k].to(self.device)
-
- if self.label_key == 'segmentation':
- targets = rearrange(targets, 'b h w c -> b c h w')
- for down in range(self.numd):
- h, w = targets.shape[-2:]
- targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
-
- # targets = rearrange(targets,'b c h w -> b h w c')
-
- return targets
-
- def compute_top_k(self, logits, labels, k, reduction="mean"):
- _, top_ks = torch.topk(logits, k, dim=1)
- if reduction == "mean":
- return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
- elif reduction == "none":
- return (top_ks == labels[:, None]).float().sum(dim=-1)
-
- def on_train_epoch_start(self):
- # save some memory
- self.diffusion_model.model.to('cpu')
-
- @torch.no_grad()
- def write_logs(self, loss, logits, targets):
- log_prefix = 'train' if self.training else 'val'
- log = {}
- log[f"{log_prefix}/loss"] = loss.mean()
- log[f"{log_prefix}/acc@1"] = self.compute_top_k(
- logits, targets, k=1, reduction="mean"
- )
- log[f"{log_prefix}/acc@5"] = self.compute_top_k(
- logits, targets, k=5, reduction="mean"
- )
-
- self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
- self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
- self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
-
- def shared_step(self, batch, t=None):
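-        # noise the diffusion model's first-stage inputs to timestep t and train the classifier on the noisy samples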
- x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
- targets = self.get_conditioning(batch)
- if targets.dim() == 4:
- targets = targets.argmax(dim=1)
- if t is None:
- t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
- else:
- t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
- x_noisy = self.get_x_noisy(x, t)
- logits = self(x_noisy, t)
-
- loss = F.cross_entropy(logits, targets, reduction='none')
-
- self.write_logs(loss.detach(), logits.detach(), targets.detach())
-
- loss = loss.mean()
- return loss, logits, x_noisy, targets
-
- def training_step(self, batch, batch_idx):
- loss, *_ = self.shared_step(batch)
- return loss
-
- def reset_noise_accs(self):
- self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
- range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
-
- def on_validation_start(self):
- self.reset_noise_accs()
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- loss, *_ = self.shared_step(batch)
-
- for t in self.noisy_acc:
- _, logits, _, targets = self.shared_step(batch, t)
- self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
- self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
-
- return loss
-
- def configure_optimizers(self):
- optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
-
- if self.use_scheduler:
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [optimizer], scheduler
-
- return optimizer
-
- @torch.no_grad()
- def log_images(self, batch, N=8, *args, **kwargs):
- log = dict()
- x = self.get_input(batch, self.diffusion_model.first_stage_key)
- log['inputs'] = x
-
- y = self.get_conditioning(batch)
-
- if self.label_key == 'class_label':
- y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
- log['labels'] = y
-
- if ismap(y):
- log['labels'] = self.diffusion_model.to_rgb(y)
-
- for step in range(self.log_steps):
- current_time = step * self.log_time_interval
-
- _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
-
- log[f'inputs@t{current_time}'] = x_noisy
-
- pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
- pred = rearrange(pred, 'b h w c -> b c h w')
-
- log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
-
- for key in log:
- log[key] = log[key][:N]
-
- return log
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py
deleted file mode 100644
index 91335d637..000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
- extract_into_tensor
-
-
-class DDIMSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
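-        # note: unlike nn.Module.register_buffer, this just moves tensors to CUDA and stores them as plain attributes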
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
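-        # precompute the DDIM timestep subsequence and the alpha/sigma tables used by p_sample_ddim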
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
- alphas_cumprod = self.model.alphas_cumprod
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer('betas', to_torch(self.model.betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
- self.register_buffer('ddim_sigmas', ddim_sigmas)
- self.register_buffer('ddim_alphas', ddim_alphas)
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
- @torch.no_grad()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
- samples, intermediates = self.ddim_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def ddim_sampling(self, cond, shape,
- x_T=None, ddim_use_original_steps=False,
- callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, log_every_t=100,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- if timesteps is None:
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
-        elif not ddim_use_original_steps:
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
-
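-            # if a mask is given, keep the region where mask == 1 from the (noised) x0 and only denoise the rest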
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
- img = img_orig * mask + (1. - mask) * img
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised, temperature=temperature,
- noise_dropout=noise_dropout, score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning)
- img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates['x_inter'].append(img)
- intermediates['pred_x0'].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None):
- b, *_, device = *x.shape, x.device
-
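-        # classifier-free guidance: batch the conditional and unconditional inputs and extrapolate between the two noise predictions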
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- @torch.no_grad()
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
- # fast, but does not allow for exact reconstruction
- # t serves as an index to gather the correct alphas
- if use_original_steps:
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
- else:
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
-
- if noise is None:
- noise = torch.randn_like(x0)
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
-
- @torch.no_grad()
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- use_original_steps=False):
-
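-        # reverse DDIM from an already-encoded latent for the first t_start steps (e.g. for img2img)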
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
- timesteps = timesteps[:t_start]
-
- time_range = np.flip(timesteps)
- total_steps = timesteps.shape[0]
- print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
- x_dec = x_latent
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning)
- return x_dec
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py
deleted file mode 100644
index 9633ec3d8..000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py
+++ /dev/null
@@ -1,1554 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-import pytorch_lightning as pl
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
-from contextlib import contextmanager
-from functools import partial
-from tqdm import tqdm
-from torchvision.utils import make_grid
-
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities import rank_zero_info
-
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
-from ldm.modules.x_transformer import *
-from ldm.modules.encoders.modules import *
-
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder
-
-
-__conditioning_keys__ = {'concat': 'c_concat',
- 'crossattn': 'c_crossattn',
- 'adm': 'y'}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def uniform_on_device(r1, r2, shape, device):
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
-
-
-class DDPM(pl.LightningModule):
- # classic DDPM with Gaussian diffusion, in image space
- def __init__(self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- image_size=256,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.,
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.,
- use_fp16 = True,
- ):
- super().__init__()
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
- self.parameterization = parameterization
- rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
- self.cond_stage_model = None
- self.clip_denoised = clip_denoised
- self.log_every_t = log_every_t
- self.first_stage_key = first_stage_key
- self.image_size = image_size # try conv?
- self.channels = channels
- self.use_positional_encodings = use_positional_encodings
- self.unet_config = unet_config
- self.conditioning_key = conditioning_key
- # self.model = DiffusionWrapper(unet_config, conditioning_key)
- # count_params(self.model, verbose=True)
- self.use_ema = use_ema
- # if self.use_ema:
- # self.model_ema = LitEma(self.model)
- # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- self.use_scheduler = scheduler_config is not None
- if self.use_scheduler:
- self.scheduler_config = scheduler_config
-
- self.v_posterior = v_posterior
- self.original_elbo_weight = original_elbo_weight
- self.l_simple_weight = l_simple_weight
-
- if monitor is not None:
- self.monitor = monitor
- self.ckpt_path = ckpt_path
- self.ignore_keys = ignore_keys
- self.load_only_unet = load_only_unet
- self.given_betas = given_betas
- self.beta_schedule = beta_schedule
- self.timesteps = timesteps
- self.linear_start = linear_start
- self.linear_end = linear_end
- self.cosine_s = cosine_s
- # if ckpt_path is not None:
- # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
- #
- # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
- # linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-
- self.loss_type = loss_type
-
- self.learn_logvar = learn_logvar
- self.logvar_init = logvar_init
- # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
- # if self.learn_logvar:
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-
- self.use_fp16 = use_fp16
- if use_fp16:
- self.unet_config["params"].update({"use_fp16": True})
- rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"]))
- else:
- self.unet_config["params"].update({"use_fp16": False})
- rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"]))
-
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
- cosine_s=cosine_s)
- alphas = 1. - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
-
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer('betas', to_torch(betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
- self.register_buffer('posterior_mean_coef1', to_torch(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
- self.register_buffer('posterior_mean_coef2', to_torch(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
-
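-        # lvlb_weights: per-timestep weights for the variational lower-bound term, derived from the beta schedule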
- if self.parameterization == "eps":
- lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
- elif self.parameterization == "x0":
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
- else:
- raise NotImplementedError("mu not supported")
- # TODO how to choose this term
- lvlb_weights[0] = lvlb_weights[1]
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
- assert not torch.isnan(self.lvlb_weights).all()
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.model.parameters())
- self.model_ema.copy_to(self.model)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.model.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
- return mean, variance, log_variance
-
- def predict_start_from_noise(self, x_t, t, noise):
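-        # x_0 = sqrt(1 / alphas_cumprod_t) * x_t - sqrt(1 / alphas_cumprod_t - 1) * noise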
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_out = self.model(x, t)
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def p_sample_loop(self, shape, return_intermediates=False):
- device = self.betas.device
- b = shape[0]
- img = torch.randn(shape, device=device)
- intermediates = [img]
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
- clip_denoised=self.clip_denoised)
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
- intermediates.append(img)
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16, return_intermediates=False):
- image_size = self.image_size
- channels = self.channels
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
- return_intermediates=return_intermediates)
-
- def q_sample(self, x_start, t, noise=None):
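-        # closed-form forward diffusion: sample x_t ~ q(x_t | x_0) directly for an arbitrary timestep t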
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
-
- def get_loss(self, pred, target, mean=True):
-
- if pred.isnan().any():
- print("Warning: Prediction has nan values")
- lr = self.optimizers().param_groups[0]['lr']
- # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
- print(f"lr: {lr}")
- if pred.isinf().any():
- print("Warning: Prediction has inf values")
-
- if self.use_fp16:
- target = target.half()
-
- if self.loss_type == 'l1':
- loss = (target - pred).abs()
- if mean:
- loss = loss.mean()
- elif self.loss_type == 'l2':
- if mean:
- loss = torch.nn.functional.mse_loss(target, pred)
- else:
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
- else:
-            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
-
- if loss.isnan().any():
- print("Warning: loss has nan values")
- print("loss: ", loss[0][0][0])
- raise ValueError("loss has nan values")
- if loss.isinf().any():
- print("Warning: loss has inf values")
- print("loss: ", loss)
- raise ValueError("loss has inf values")
-
- return loss
-
- def p_losses(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_out = self.model(x_noisy, t)
-
- loss_dict = {}
- if self.parameterization == "eps":
- target = noise
- elif self.parameterization == "x0":
- target = x_start
- else:
-            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
-
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
-
- log_prefix = 'train' if self.training else 'val'
-
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
- loss_simple = loss.mean() * self.l_simple_weight
-
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
-
- loss = loss_simple + self.original_elbo_weight * loss_vlb
-
- loss_dict.update({f'{log_prefix}/loss': loss})
-
- return loss, loss_dict
-
- def forward(self, x, *args, **kwargs):
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- return self.p_losses(x, t, *args, **kwargs)
-
- def get_input(self, batch, k):
- # print("+" * 30)
- # print(batch['jpg'].shape)
- # print(len(batch['txt']))
- # print(k)
- # print("=" * 30)
- if not isinstance(batch, torch.Tensor):
- x = batch[k]
- else:
- x = batch
- if len(x.shape) == 3:
- x = x[..., None]
- x = rearrange(x, 'b h w c -> b c h w')
-
- if self.use_fp16:
- x = x.to(memory_format=torch.contiguous_format).float().half()
- else:
- x = x.to(memory_format=torch.contiguous_format).float()
-
- return x
-
- def shared_step(self, batch):
- x = self.get_input(batch, self.first_stage_key)
- loss, loss_dict = self(x)
- return loss, loss_dict
-
- def training_step(self, batch, batch_idx):
- loss, loss_dict = self.shared_step(batch)
-
- self.log_dict(loss_dict, prog_bar=True,
- logger=True, on_step=True, on_epoch=True)
-
- self.log("global_step", self.global_step,
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- if self.use_scheduler:
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- return loss
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- _, loss_dict_no_ema = self.shared_step(batch)
- with self.ema_scope():
- _, loss_dict_ema = self.shared_step(batch)
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self.model)
-
- def _get_rows_from_list(self, samples):
- n_imgs_per_row = len(samples)
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
- x = self.get_input(batch, self.first_stage_key)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- x = x.to(self.device)[:N]
- log["inputs"] = x
-
- # get diffusion row
- diffusion_row = list()
- x_start = x[:n_row]
-
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(x_start)
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- diffusion_row.append(x_noisy)
-
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
-
- log["samples"] = samples
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.learn_logvar:
- params = params + [self.logvar]
- opt = torch.optim.AdamW(params, lr=lr)
- return opt
-
-
-class LatentDiffusion(DDPM):
- """main class"""
- def __init__(self,
- first_stage_config,
- cond_stage_config,
- num_timesteps_cond=None,
- cond_stage_key="image",
- cond_stage_trainable=False,
- concat_mode=True,
- cond_stage_forward=None,
- conditioning_key=None,
- scale_factor=1.0,
- scale_by_std=False,
- use_fp16=True,
- *args, **kwargs):
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
- self.scale_by_std = scale_by_std
- assert self.num_timesteps_cond <= kwargs['timesteps']
- # for backwards compatibility after implementation of DiffusionWrapper
- if conditioning_key is None:
- conditioning_key = 'concat' if concat_mode else 'crossattn'
- if cond_stage_config == '__is_unconditional__':
- conditioning_key = None
- ckpt_path = kwargs.pop("ckpt_path", None)
- ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, use_fp16=use_fp16, *args, **kwargs)
- self.concat_mode = concat_mode
- self.cond_stage_trainable = cond_stage_trainable
- self.cond_stage_key = cond_stage_key
- try:
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
- self.num_downs = 0
- if not scale_by_std:
- self.scale_factor = scale_factor
- else:
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
- self.first_stage_config = first_stage_config
- self.cond_stage_config = cond_stage_config
- if self.use_fp16:
- self.cond_stage_config["params"].update({"use_fp16": True})
- rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
- else:
- self.cond_stage_config["params"].update({"use_fp16": False})
- rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
- # self.instantiate_first_stage(first_stage_config)
- # self.instantiate_cond_stage(cond_stage_config)
- self.cond_stage_forward = cond_stage_forward
- self.clip_denoised = False
- self.bbox_tokenizer = None
-
- self.restarted_from_ckpt = False
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys)
- self.restarted_from_ckpt = True
-
-
-
- def configure_sharded_model(self) -> None:
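-        # Lightning hook: the heavy modules are built here rather than in __init__ so that a sharded strategy can instantiate them in its own context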
- self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
- count_params(self.model, verbose=True)
- if self.use_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
-
- self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps,
- linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s)
-
- self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
- if self.learn_logvar:
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- if self.ckpt_path is not None:
- self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
- self.restarted_from_ckpt = True
-
- # TODO()
- # for p in self.model.modules():
- # if not p.parameters().data.is_contiguous:
- # p.data = p.data.contiguous()
-
- self.instantiate_first_stage(self.first_stage_config)
- self.instantiate_cond_stage(self.cond_stage_config)
-
- def make_cond_schedule(self, ):
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
- self.cond_ids[:self.num_timesteps_cond] = ids
-
-
-
- @rank_zero_only
- @torch.no_grad()
- # def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
- def on_train_batch_start(self, batch, batch_idx):
- # only for very first batch
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
- # set rescale weight to 1./std of encodings
- print("### USING STD-RESCALING ###")
- x = super().get_input(batch, self.first_stage_key)
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- del self.scale_factor
- self.register_buffer('scale_factor', 1. / z.flatten().std())
- print(f"setting self.scale_factor to {self.scale_factor}")
- print("### USING STD-RESCALING ###")
-
- def register_schedule(self,
- given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
-
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
- if self.shorten_cond_schedule:
- self.make_cond_schedule()
-
- def instantiate_first_stage(self, config):
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()
- self.first_stage_model.train = disabled_train
- for param in self.first_stage_model.parameters():
- param.requires_grad = False
-
- def instantiate_cond_stage(self, config):
- if not self.cond_stage_trainable:
- if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
- self.cond_stage_model = self.first_stage_model
- elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
- self.cond_stage_model = None
- # self.be_unconditional = True
- else:
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()
- self.cond_stage_model.train = disabled_train
- for param in self.cond_stage_model.parameters():
- param.requires_grad = False
- else:
- assert config != '__is_first_stage__'
- assert config != '__is_unconditional__'
- model = instantiate_from_config(config)
- self.cond_stage_model = model
-
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
- denoise_row = []
- for zd in tqdm(samples, desc=desc):
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
- n_imgs_per_row = len(denoise_row)
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
- return self.scale_factor * z
-
- def get_learned_conditioning(self, c):
- if self.cond_stage_forward is None:
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
- c = self.cond_stage_model.encode(c)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- else:
- c = self.cond_stage_model(c)
- else:
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
- return c
-
- def meshgrid(self, h, w):
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
-
- arr = torch.cat([y, x], dim=-1)
- return arr
-
- def delta_border(self, h, w):
- """
- :param h: height
- :param w: width
- :return: normalized distance to image border,
-        with min distance = 0 at the border and max distance = 0.5 at the image center
- """
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
- arr = self.meshgrid(h, w) / lower_right_corner
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
- return edge_dist
-
- def get_weighting(self, h, w, Ly, Lx, device):
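-        # per-pixel weights (based on distance to the crop border) used to blend overlapping crops smoothly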
- weighting = self.delta_border(h, w)
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
- self.split_input_params["clip_max_weight"], )
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
-
- if self.split_input_params["tie_braker"]:
- L_weighting = self.delta_border(Ly, Lx)
- L_weighting = torch.clip(L_weighting,
- self.split_input_params["clip_min_tie_weight"],
- self.split_input_params["clip_max_tie_weight"])
-
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
- weighting = weighting * L_weighting
- return weighting
-
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
- """
- :param x: img of size (bs, c, h, w)
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
- """
- bs, nc, h, w = x.shape
-
- # number of crops in image
- Ly = (h - kernel_size[0]) // stride[0] + 1
- Lx = (w - kernel_size[1]) // stride[1] + 1
-
- if uf == 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
-
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
-
- elif uf > 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
- dilation=1, padding=0,
- stride=(stride[0] * uf, stride[1] * uf))
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
-
- elif df > 1 and uf == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
- dilation=1, padding=0,
- stride=(stride[0] // df, stride[1] // df))
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
-
- else:
- raise NotImplementedError
-
- return fold, unfold, normalization, weighting
-
- @torch.no_grad()
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
- cond_key=None, return_original_cond=False, bs=None):
- x = super().get_input(batch, k)
- if bs is not None:
- x = x[:bs]
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
-
- if self.model.conditioning_key is not None:
- if cond_key is None:
- cond_key = self.cond_stage_key
- if cond_key != self.first_stage_key:
- if cond_key in ['caption', 'coordinates_bbox', 'txt']:
- xc = batch[cond_key]
- elif cond_key == 'class_label':
- xc = batch
- else:
- xc = super().get_input(batch, cond_key).to(self.device)
- else:
- xc = x
- if not self.cond_stage_trainable or force_c_encode:
- if isinstance(xc, dict) or isinstance(xc, list):
- # import pudb; pudb.set_trace()
- c = self.get_learned_conditioning(xc)
- else:
- c = self.get_learned_conditioning(xc.to(self.device))
- else:
- c = xc
- if bs is not None:
- c = c[:bs]
-
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- ckey = __conditioning_keys__[self.model.conditioning_key]
- c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
-
- else:
- c = None
- xc = None
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- c = {'pos_x': pos_x, 'pos_y': pos_y}
- out = [z, c]
- if return_first_stage_outputs:
- xrec = self.decode_first_stage(z)
- out.extend([x, xrec])
- if return_original_cond:
- out.append(xc)
- return out
-
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
-                ks = self.split_input_params["ks"]  # e.g. (128, 128)
-                stride = self.split_input_params["stride"]  # e.g. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- # same as above but without decorator
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
-                ks = self.split_input_params["ks"]  # e.g. (128, 128)
-                stride = self.split_input_params["stride"]  # e.g. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- @torch.no_grad()
- def encode_first_stage(self, x):
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
-                ks = self.split_input_params["ks"]  # e.g. (128, 128)
-                stride = self.split_input_params["stride"]  # e.g. (64, 64)
- df = self.split_input_params["vqf"]
- self.split_input_params['original_image_size'] = x.shape[-2:]
- bs, nc, h, w = x.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
- z = unfold(x) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
-
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization
- return decoded
-
- else:
- return self.first_stage_model.encode(x)
- else:
- return self.first_stage_model.encode(x)
-
- def shared_step(self, batch, **kwargs):
- x, c = self.get_input(batch, self.first_stage_key)
- loss = self(x, c)
- return loss
-
- def forward(self, x, c, *args, **kwargs):
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- if self.model.conditioning_key is not None:
- assert c is not None
- if self.cond_stage_trainable:
- c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
- tc = self.cond_ids[t].to(self.device)
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
- return self.p_losses(x, c, t, *args, **kwargs)
-
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
- if isinstance(cond, dict):
-            # hybrid case, cond is expected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- if hasattr(self, "split_input_params"):
- assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
-            ks = self.split_input_params["ks"]  # e.g. (128, 128)
-            stride = self.split_input_params["stride"]  # e.g. (64, 64)
-
- h, w = x_noisy.shape[-2:]
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
-
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
- c_key = next(iter(cond.keys())) # get key
- c = next(iter(cond.values())) # get value
- assert (len(c) == 1) # todo extend to list with more than one elem
- c = c[0] # get element
-
- c = unfold(c)
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
-
- elif self.cond_stage_key == 'coordinates_bbox':
-                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
-
- # assuming padding of unfold is always 0 and its dilation is always 1
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
- full_img_h, full_img_w = self.split_input_params['original_image_size']
- # as we are operating on latents, we need the factor from the original image size to the
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
- rescale_latent = 2 ** (num_downs)
-
-                # get the top-left positions of the patches in the format expected by the bbox tokenizer; therefore we
-                # need to rescale the top-left patch coordinates to lie in (0, 1)
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
- for patch_nr in range(z.shape[-1])]
-
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
- patch_limits = [(x_tl, y_tl,
- rescale_latent * ks[0] / full_img_w,
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
-
- # tokenize crop coordinates for the bounding boxes of the respective patches
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
- print(patch_limits_tknzd[0].shape)
- # cut tknzd crop position from conditioning
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
- print(cut_cond.shape)
-
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
- print(adapted_cond.shape)
- adapted_cond = self.get_learned_conditioning(adapted_cond)
- print(adapted_cond.shape)
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
- print(adapted_cond.shape)
-
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
-
- else:
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
-
- # apply model by loop over crops
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
- assert not isinstance(output_list[0],
- tuple) # todo: can't deal with multiple model outputs; check this never happens
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- x_recon = fold(o) / normalization
-
- else:
- x_recon = self.model(x_noisy, t, **cond)
-
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
-
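- # Invert the forward-diffusion relation x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps
- # to recover the noise estimate eps from x_t and a predicted x_0.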
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
- def _prior_bpd(self, x_start):
- """
- Get the prior KL term for the variational lower-bound, measured in
- bits-per-dim.
- This term can't be optimized, as it only depends on the encoder.
- :param x_start: the [N x C x ...] tensor of inputs.
- :return: a batch of [N] KL values (in bits), one per batch element.
- """
- batch_size = x_start.shape[0]
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
- return mean_flat(kl_prior) / np.log(2.0)
-
- def p_losses(self, x_start, cond, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_output = self.apply_model(x_noisy, t, cond)
-
- loss_dict = {}
- prefix = 'train' if self.training else 'val'
-
- if self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "eps":
- target = noise
- else:
- raise NotImplementedError()
-
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
-
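- # Reweight the per-sample loss by a (possibly learned) per-timestep log-variance;
- # the logvar term itself is added back so it cannot collapse to -inf.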
- logvar_t = self.logvar[t].to(self.device)
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
- if self.learn_logvar:
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
- loss_dict.update({'logvar': self.logvar.data.mean()})
-
- loss = self.l_simple_weight * loss.mean()
-
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
- loss += (self.original_elbo_weight * loss_vlb)
- loss_dict.update({f'{prefix}/loss': loss})
-
- return loss, loss_dict
-
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
- return_x0=False, score_corrector=None, corrector_kwargs=None):
- t_in = t
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
-
- if score_corrector is not None:
- assert self.parameterization == "eps"
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
-
- if return_codebook_ids:
- model_out, logits = model_out
-
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- else:
- raise NotImplementedError()
-
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
- if quantize_denoised:
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- if return_codebook_ids:
- return model_mean, posterior_variance, posterior_log_variance, logits
- elif return_x0:
- return model_mean, posterior_variance, posterior_log_variance, x_recon
- else:
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
- b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
- return_codebook_ids=return_codebook_ids,
- quantize_denoised=quantize_denoised,
- return_x0=return_x0,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if return_codebook_ids:
- raise DeprecationWarning("Support dropped.")
- model_mean, _, model_log_variance, logits = outputs
- elif return_x0:
- model_mean, _, model_log_variance, x0 = outputs
- else:
- model_mean, _, model_log_variance = outputs
-
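- # Sample the reverse-step noise, scaled by temperature and optionally thinned with dropout;
- # the (t == 0) mask below suppresses noise entirely at the final step.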
- noise = noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-
- if return_codebook_ids:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
- if return_x0:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
- else:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
- log_every_t=None):
- if not log_every_t:
- log_every_t = self.log_every_t
- timesteps = self.num_timesteps
- if batch_size is not None:
- b = batch_size
- shape = [batch_size] + list(shape)
- else:
- b = batch_size = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=self.device)
- else:
- img = x_T
- intermediates = []
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
- total=timesteps) if verbose else reversed(
- range(0, timesteps))
- if type(temperature) == float:
- temperature = [temperature] * timesteps
-
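- # Walk the chain backwards from start_T (or num_timesteps), re-noising the conditioning
- # when the cond schedule is shortened, and log the intermediate x0 predictions.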
- for i in iterator:
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img, x0_partial = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised, return_x0=True,
- temperature=temperature[i], noise_dropout=noise_dropout,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if mask is not None:
- assert x0 is not None
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_loop(self, cond, shape, return_intermediates=False,
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, start_T=None,
- log_every_t=None):
-
- if not log_every_t:
- log_every_t = self.log_every_t
- device = self.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- intermediates = [img]
- if timesteps is None:
- timesteps = self.num_timesteps
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
- range(0, timesteps))
-
- if mask is not None:
- assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
-
- for i in iterator:
- ts = torch.full((b,), i, device=device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised)
- if mask is not None:
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
-
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
- verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None,**kwargs):
- if shape is None:
- shape = (batch_size, self.channels, self.image_size, self.image_size)
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
- return self.p_sample_loop(cond,
- shape,
- return_intermediates=return_intermediates, x_T=x_T,
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
- mask=mask, x0=x0)
-
- @torch.no_grad()
- def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
-
- if ddim:
- ddim_sampler = DDIMSampler(self)
- shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
- shape, cond, verbose=False, **kwargs)
-
- else:
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True,**kwargs)
-
- return samples, intermediates
-
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, **kwargs):
-
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
- return_first_stage_outputs=True,
- force_c_encode=True,
- return_original_cond=True,
- bs=N)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reconstruction"] = xrec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
- log["conditioning"] = xc
- elif self.cond_stage_key == 'class_label':
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
- log['conditioning'] = xc
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
- self.first_stage_model, IdentityFirstStage):
- # also display when quantizing x0 while sampling
- with self.ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta,
- quantize_denoised=True)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
- # quantize_denoised=True)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_x0_quantized"] = x_samples
-
- if inpaint:
- # make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
- mask = torch.ones(N, h, w).to(self.device)
- # zeros will be filled in
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
- mask = mask[:, None, ...]
- with self.ema_scope("Plotting Inpaint"):
-
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_inpainting"] = x_samples
- log["mask"] = mask
-
- # outpaint
- with self.ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_outpainting"] = x_samples
-
- if plot_progressive_rows:
- with self.ema_scope("Plotting Progressives"):
- img, progressives = self.progressive_denoising(c,
- shape=(self.channels, self.image_size, self.image_size),
- batch_size=N)
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
- log["progressive_row"] = prog_row
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.cond_stage_trainable:
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
- params = params + list(self.cond_stage_model.parameters())
- if self.learn_logvar:
- print('Diffusion model optimizing logvar')
- params.append(self.logvar)
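- # Use ColossalAI's HybridAdam in place of torch.optim.AdamW (see the commented-out line below).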
- from colossalai.nn.optimizer import HybridAdam
- opt = HybridAdam(params, lr=lr)
- # opt = torch.optim.AdamW(params, lr=lr)
- if self.use_scheduler:
- assert 'target' in self.scheduler_config
- scheduler = instantiate_from_config(self.scheduler_config)
-
- rank_zero_info("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [opt], scheduler
- return opt
-
- @torch.no_grad()
- def to_rgb(self, x):
- x = x.float()
- if not hasattr(self, "colorize"):
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
- x = nn.functional.conv2d(x, weight=self.colorize)
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
- return x
-
-
-class DiffusionWrapper(pl.LightningModule):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
-
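- # Route the conditioning into the diffusion model according to conditioning_key:
- # 'concat' joins it along channels, 'crossattn' passes it as attention context,
- # 'hybrid' does both, and 'adm' passes a class-style embedding as y.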
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t)
- elif self.conditioning_key == 'concat':
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t)
- elif self.conditioning_key == 'crossattn':
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc)
- elif self.conditioning_key == 'hybrid':
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc)
- elif self.conditioning_key == 'adm':
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc)
- else:
- raise NotImplementedError()
-
- return out
-
-
-class Layout2ImgDiffusion(LatentDiffusion):
- # TODO: move all layout-specific hacks to this class
- def __init__(self, cond_stage_key, *args, **kwargs):
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
-
- def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
-
- key = 'train' if self.training else 'validation'
- dset = self.trainer.datamodule.datasets[key]
- mapper = dset.conditional_builders[self.cond_stage_key]
-
- bbox_imgs = []
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
- bbox_imgs.append(bboximg)
-
- cond_img = torch.stack(bbox_imgs, dim=0)
- logs['bbox_image'] = cond_img
- return logs
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py
deleted file mode 100644
index 78eeb1003..000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py
+++ /dev/null
@@ -1,236 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-
-
-class PLMSSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
- if ddim_eta != 0:
- raise ValueError('ddim_eta must be 0 for PLMS')
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
- alphas_cumprod = self.model.alphas_cumprod
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer('betas', to_torch(self.model.betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
- self.register_buffer('ddim_sigmas', ddim_sigmas)
- self.register_buffer('ddim_alphas', ddim_alphas)
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
- @torch.no_grad()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for PLMS sampling is {size}')
-
- samples, intermediates = self.plms_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def plms_sampling(self, cond, shape,
- x_T=None, ddim_use_original_steps=False,
- callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, log_every_t=100,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- if timesteps is None:
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
- elif timesteps is not None and not ddim_use_original_steps:
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- print(f"Running PLMS Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
- old_eps = []
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
- ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
-
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
- img = img_orig * mask + (1. - mask) * img
-
- outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised, temperature=temperature,
- noise_dropout=noise_dropout, score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- old_eps=old_eps, t_next=ts_next)
- img, pred_x0, e_t = outs
- old_eps.append(e_t)
- if len(old_eps) >= 4:
- old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates['x_inter'].append(img)
- intermediates['pred_x0'].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
- b, *_, device = *x.shape, x.device
-
- def get_model_output(x, t):
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- return e_t
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
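- # PLMS update: combine the current eps estimate with up to three previous ones using
- # Adams-Bashforth coefficients (2nd-order pseudo improved Euler bootstrap on the first step).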
- e_t = get_model_output(x, t)
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = get_model_output(x_prev, t_next)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- elif len(old_eps) >= 3:
- # 4th order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- return x_prev, pred_x0, e_t
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/attention.py b/examples/tutorial/stable_diffusion/ldm/modules/attention.py
deleted file mode 100644
index 3401ceafd..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/attention.py
+++ /dev/null
@@ -1,314 +0,0 @@
-from inspect import isfunction
-import math
-import torch
-import torch.nn.functional as F
-from torch import nn, einsum
-from einops import rearrange, repeat
-
-from torch.utils.checkpoint import checkpoint
-
-try:
- from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
- FlASH_AVAILABLE = True
-except ImportError:
- FlASH_AVAILABLE = False
-
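-# Module-level switch; flipped to True by enable_flash_attention() below.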
-USE_FLASH = False
-
-
-def enable_flash_attention():
- global USE_FLASH
- USE_FLASH = True
- if FlASH_AVAILABLE is False:
- print("Please install flash attention to activate new attention kernel.\n" +
- "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
-
-
-def exists(val):
- return val is not None
-
-
-def uniq(arr):
- return {el: True for el in arr}.keys()
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
- return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
- dim = tensor.shape[-1]
- std = 1 / math.sqrt(dim)
- tensor.uniform_(-std, std)
- return tensor
-
-
-# feedforward
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
- nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim)
-
- self.net = nn.Sequential(
- project_in,
- nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def Normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-class LinearAttention(nn.Module):
- def __init__(self, dim, heads=4, dim_head=32):
- super().__init__()
- self.heads = heads
- hidden_dim = dim_head * heads
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
- def forward(self, x):
- b, c, h, w = x.shape
- qkv = self.to_qkv(x)
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
- k = k.softmax(dim=-1)
- context = torch.einsum('bhdn,bhen->bhde', k, v)
- out = torch.einsum('bhde,bhdn->bhen', context, q)
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
- return self.to_out(out)
-
-
-class SpatialSelfAttention(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b,c,h,w = q.shape
- q = rearrange(q, 'b c h w -> b (h w) c')
- k = rearrange(k, 'b c h w -> b c (h w)')
- w_ = torch.einsum('bij,bjk->bik', q, k)
-
- w_ = w_ * (int(c)**(-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = rearrange(v, 'b c h w -> b c (h w)')
- w_ = rearrange(w_, 'b i j -> b j i')
- h_ = torch.einsum('bij,bjk->bik', v, w_)
- h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-class CrossAttention(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
- super().__init__()
- inner_dim = dim_head * heads
- context_dim = default(context_dim, query_dim)
-
- self.scale = dim_head ** -0.5
- self.heads = heads
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
- nn.Dropout(dropout)
- )
-
- def forward(self, x, context=None, mask=None):
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- dim_head = q.shape[-1] // self.heads
-
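- # Use the flash-attention kernels only when enabled, available, running in (b)float16,
- # and the per-head dim is <= 128 and a multiple of 8; otherwise fall back to the
- # native einsum attention.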
- if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
- dim_head <= 128 and (dim_head % 8) == 0:
- # print("in flash")
- if q.shape[1] == k.shape[1]:
- out = self._flash_attention_qkv(q, k, v)
- else:
- out = self._flash_attention_q_kv(q, k, v)
- else:
- out = self._native_attention(q, k, v, self.heads, mask)
-
- return self.to_out(out)
-
- def _native_attention(self, q, k, v, h, mask):
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
- if exists(mask):
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
- # attention, what we cannot get enough of
- out = sim.softmax(dim=-1)
- out = einsum('b i j, b j d -> b i d', out, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return out
-
- def _flash_attention_qkv(self, q, k, v):
- qkv = torch.stack([q, k, v], dim=2)
- b = qkv.shape[0]
- n = qkv.shape[1]
- qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
- out = flash_attention_qkv(qkv, self.scale, b, n)
- out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
- return out
-
- def _flash_attention_q_kv(self, q, k, v):
- kv = torch.stack([k, v], dim=2)
- b = q.shape[0]
- q_seqlen = q.shape[1]
- kv_seqlen = kv.shape[1]
- q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
- kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
- out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
- out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
- return out
-
-
-class BasicTransformerBlock(nn.Module):
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
- super().__init__()
- self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
- self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
- self.use_checkpoint = use_checkpoint
-
- def forward(self, x, context=None):
- if self.use_checkpoint:
- return checkpoint(self._forward, x, context)
- else:
- return self._forward(x, context)
-
- def _forward(self, x, context=None):
- x = self.attn1(self.norm1(x)) + x
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
- return x
-
-
-
-class SpatialTransformer(nn.Module):
- """
- Transformer block for image-like data.
- First, project the input (aka embedding)
- and reshape to b, t, d.
- Then apply standard transformer action.
- Finally, reshape to image
- """
- def __init__(self, in_channels, n_heads, d_head,
- depth=1, dropout=0., context_dim=None, use_checkpoint=False):
- super().__init__()
- self.in_channels = in_channels
- inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
-
- self.proj_in = nn.Conv2d(in_channels,
- inner_dim,
- kernel_size=1,
- stride=1,
- padding=0)
-
- self.transformer_blocks = nn.ModuleList(
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
- for d in range(depth)]
- )
-
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0))
-
-
- def forward(self, x, context=None):
- # note: if no context is given, cross-attention defaults to self-attention
- b, c, h, w = x.shape
- x_in = x
- x = self.norm(x)
- x = self.proj_in(x)
- x = rearrange(x, 'b c h w -> b (h w) c')
- x = x.contiguous()
- for block in self.transformer_blocks:
- x = block(x, context=context)
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
- x = x.contiguous()
- x = self.proj_out(x)
- return x + x_in
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py
deleted file mode 100644
index 3c28492c5..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py
+++ /dev/null
@@ -1,862 +0,0 @@
-# pytorch_diffusion + derived encoder decoder
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import rearrange
-
-from ldm.util import instantiate_from_config
-from ldm.modules.attention import LinearAttention
-
-
-def get_timestep_embedding(timesteps, embedding_dim):
- """
- This matches the implementation in Denoising Diffusion Probabilistic Models:
- From Fairseq.
- Build sinusoidal embeddings.
- This matches the implementation in tensor2tensor, but differs slightly
- from the description in Section 3.5 of "Attention Is All You Need".
- """
- assert len(timesteps.shape) == 1
-
- half_dim = embedding_dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
- emb = emb.to(device=timesteps.device)
- emb = timesteps.float()[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
- return emb
-
-
-def nonlinearity(x):
- # swish
- return x*torch.sigmoid(x)
-
-
-def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=2,
- padding=0)
-
- def forward(self, x):
- if self.with_conv:
- pad = (0,1,0,1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class ResnetBlock(nn.Module):
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels,
- out_channels)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(out_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- else:
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
- def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
-
- if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
-
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
-
- return x+h
-
-
-class LinAttnBlock(LinearAttention):
- """to match AttnBlock usage"""
- def __init__(self, in_channels):
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b,c,h,w = q.shape
- q = q.reshape(b,c,h*w)
- q = q.permute(0,2,1) # b,hw,c
- k = k.reshape(b,c,h*w) # b,c,hw
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c)**(-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b,c,h*w)
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b,c,h,w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-def make_attn(in_channels, attn_type="vanilla"):
- assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
- if attn_type == "vanilla":
- return AttnBlock(in_channels)
- elif attn_type == "none":
- return nn.Identity(in_channels)
- else:
- return LinAttnBlock(in_channels)
-
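-# Empty nn.Module subclasses (temb_module, Down_module, Up_module, Mid_module) used purely as
-# named attribute containers in place of bare nn.Module() instances (see the commented-out lines below).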
-class temb_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Model(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = self.ch*4
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- self.use_timestep = use_timestep
- if self.use_timestep:
- # timestep embedding
- # self.temb = nn.Module()
- self.temb = temb_module()
- self.temb.dense = nn.ModuleList([
- torch.nn.Linear(self.ch,
- self.temb_ch),
- torch.nn.Linear(self.temb_ch,
- self.temb_ch),
- ])
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # down = nn.Module()
- down = Down_module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions-1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- skip_in = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
- if i_block == self.num_res_blocks:
- skip_in = ch*in_ch_mult[i_level]
- block.append(ResnetBlock(in_channels=block_in+skip_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # up = nn.Module()
- up = Up_module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x, t=None, context=None):
- #assert x.shape[2] == x.shape[3] == self.resolution
- if context is not None:
- # assume aligned context, cat along channel axis
- x = torch.cat((x, context), dim=1)
- if self.use_timestep:
- # timestep embedding
- assert t is not None
- temb = get_timestep_embedding(t, self.ch)
- temb = self.temb.dense[0](temb)
- temb = nonlinearity(temb)
- temb = self.temb.dense[1](temb)
- else:
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions-1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](
- torch.cat([h, hs.pop()], dim=1), temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
- def get_last_layer(self):
- return self.conv_out.weight
-
-class Down_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Up_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Mid_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-
-class Encoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
- **ignore_kwargs):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
- self.in_ch_mult = in_ch_mult
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # down = nn.Module()
- down = Down_module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions-1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- 2*z_channels if double_z else z_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- # timestep embedding
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions-1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class Decoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
- attn_type="vanilla", **ignorekwargs):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.tanh_out = tanh_out
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,)+tuple(ch_mult)
- block_in = ch*ch_mult[self.num_resolutions-1]
- curr_res = resolution // 2**(self.num_resolutions-1)
- self.z_shape = (1,z_channels,curr_res,curr_res)
- print("Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)))
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(z_channels,
- block_in,
- kernel_size=3,
- stride=1,
- padding=1)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # up = nn.Module()
- up = Up_module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, z):
- #assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
-
-
-class SimpleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, *args, **kwargs):
- super().__init__()
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=2 * in_channels,
- out_channels=4 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=4 * in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- nn.Conv2d(2*in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True)])
- # end
- self.norm_out = Normalize(in_channels)
- self.conv_out = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- for i, layer in enumerate(self.model):
- if i in [1,2,3]:
- x = layer(x, None)
- else:
- x = layer(x)
-
- h = self.norm_out(x)
- h = nonlinearity(h)
- x = self.conv_out(h)
- return x
-
-
-class UpsampleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
- ch_mult=(2,2), dropout=0.0):
- super().__init__()
- # upsampling
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- block_in = in_channels
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.res_blocks = nn.ModuleList()
- self.upsample_blocks = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- res_block = []
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- res_block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- self.res_blocks.append(nn.ModuleList(res_block))
- if i_level != self.num_resolutions - 1:
- self.upsample_blocks.append(Upsample(block_in, True))
- curr_res = curr_res * 2
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- # upsampling
- h = x
- for k, i_level in enumerate(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.res_blocks[i_level][i_block](h, None)
- if i_level != self.num_resolutions - 1:
- h = self.upsample_blocks[k](h)
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class LatentRescaler(nn.Module):
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
- super().__init__()
- # residual block, interpolate, residual block
- self.factor = factor
- self.conv_in = nn.Conv2d(in_channels,
- mid_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
- self.attn = AttnBlock(mid_channels)
- self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
-
- self.conv_out = nn.Conv2d(mid_channels,
- out_channels,
- kernel_size=1,
- )
-
- def forward(self, x):
- x = self.conv_in(x)
- for block in self.res_block1:
- x = block(x, None)
-        x = torch.nn.functional.interpolate(
-            x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))))
- x = self.attn(x)
- for block in self.res_block2:
- x = block(x, None)
- x = self.conv_out(x)
- return x
-
-
-class MergedRescaleEncoder(nn.Module):
- def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
- super().__init__()
- intermediate_chn = ch * ch_mult[-1]
- self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
- z_channels=intermediate_chn, double_z=False, resolution=resolution,
- attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
- out_ch=None)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
- mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
-
- def forward(self, x):
- x = self.encoder(x)
- x = self.rescaler(x)
- return x
-
-
-class MergedRescaleDecoder(nn.Module):
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
- dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
- super().__init__()
- tmp_chn = z_channels*ch_mult[-1]
- self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
- resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
- ch_mult=ch_mult, resolution=resolution, ch=ch)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
- out_channels=tmp_chn, depth=rescale_module_depth)
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Upsampler(nn.Module):
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
- super().__init__()
- assert out_size >= in_size
- num_blocks = int(np.log2(out_size//in_size))+1
- factor_up = 1.+ (out_size % in_size)
- print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
- out_channels=in_channels)
- self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
- attn_resolutions=[], in_channels=None, ch=in_channels,
- ch_mult=[ch_mult for _ in range(num_blocks)])
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Resize(nn.Module):
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
- super().__init__()
- self.with_conv = learned
- self.mode = mode
- if self.with_conv:
-            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
- raise NotImplementedError()
- assert in_channels is not None
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=4,
- stride=2,
- padding=1)
-
- def forward(self, x, scale_factor=1.0):
-        if scale_factor == 1.0:
- return x
- else:
- x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
- return x
-
-class FirstStagePostProcessor(nn.Module):
-
- def __init__(self, ch_mult:list, in_channels,
- pretrained_model:nn.Module=None,
- reshape=False,
- n_channels=None,
- dropout=0.,
- pretrained_config=None):
- super().__init__()
- if pretrained_config is None:
- assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.pretrained_model = pretrained_model
- else:
- assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.instantiate_pretrained(pretrained_config)
-
- self.do_reshape = reshape
-
- if n_channels is None:
- n_channels = self.pretrained_model.encoder.ch
-
- self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
- self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
- stride=1,padding=1)
-
- blocks = []
- downs = []
- ch_in = n_channels
- for m in ch_mult:
- blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
- ch_in = m * n_channels
- downs.append(Downsample(ch_in, with_conv=False))
-
- self.model = nn.ModuleList(blocks)
- self.downsampler = nn.ModuleList(downs)
-
-
- def instantiate_pretrained(self, config):
- model = instantiate_from_config(config)
- self.pretrained_model = model.eval()
- # self.pretrained_model.train = False
- for param in self.pretrained_model.parameters():
- param.requires_grad = False
-
-
- @torch.no_grad()
- def encode_with_pretrained(self,x):
- c = self.pretrained_model.encode(x)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- return c
-
- def forward(self,x):
- z_fs = self.encode_with_pretrained(x)
- z = self.proj_norm(z_fs)
- z = self.proj(z)
- z = nonlinearity(z)
-
- for submodel, downmodel in zip(self.model,self.downsampler):
- z = submodel(z,temb=None)
- z = downmodel(z)
-
- if self.do_reshape:
- z = rearrange(z,'b c h w -> b (h w) c')
- return z
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
deleted file mode 100644
index 3aedc2205..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
+++ /dev/null
@@ -1,1152 +0,0 @@
-from abc import abstractmethod
-from functools import partial
-import math
-from typing import Iterable
-
-import numpy as np
-import torch
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint  # import the function, not the module, so checkpoint(...) below is callable
-
-from ldm.modules.diffusionmodules.util import (
- conv_nd,
- linear,
- avg_pool_nd,
- zero_module,
- normalization,
- timestep_embedding,
-)
-from ldm.modules.attention import SpatialTransformer
-
-
-# dummy replace
-def convert_module_to_f16(x):
- # for n,p in x.named_parameter():
- # print(f"convert module {n} to_f16")
- # p.data = p.data.half()
- pass
-
-def convert_module_to_f32(x):
- pass
-
-
-## go
-class AttentionPool2d(nn.Module):
- """
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- """
-
- def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
- ):
- super().__init__()
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
- self.num_heads = embed_dim // num_heads_channels
- self.attention = QKVAttention(self.num_heads)
-
- def forward(self, x):
- b, c, *_spatial = x.shape
- x = x.reshape(b, c, -1) # NC(HW)
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
- x = self.qkv_proj(x)
- x = self.attention(x)
- x = self.c_proj(x)
- return x[:, :, 0]
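-
-    # Shape walk-through (illustrative): for an input of shape [N, C, H, W],
-    # forward flattens the spatial grid to [N, C, H*W], prepends the mean token
-    # to get [N, C, H*W + 1], adds the learned positional embedding, runs
-    # multi-head QKV attention, and returns the attended mean token of shape
-    # [N, output_dim] (or [N, embed_dim] when output_dim is None) as the pooled feature.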
-
-
-class TimestepBlock(nn.Module):
- """
- Any module where forward() takes timestep embeddings as a second argument.
- """
-
- @abstractmethod
- def forward(self, x, emb):
- """
- Apply the module to `x` given `emb` timestep embeddings.
- """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
- """
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
- """
-
- def forward(self, x, emb, context=None):
- for layer in self:
- if isinstance(layer, TimestepBlock):
- x = layer(x, emb)
- elif isinstance(layer, SpatialTransformer):
- x = layer(x, context)
- else:
- x = layer(x)
- return x
-
-
-class Upsample(nn.Module):
- """
- An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- if use_conv:
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- if self.dims == 3:
- x = F.interpolate(
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
- )
- else:
- x = F.interpolate(x, scale_factor=2, mode="nearest")
- if self.use_conv:
- x = self.conv(x)
- return x
-
-class TransposedUpsample(nn.Module):
- 'Learned 2x upsampling without padding'
- def __init__(self, channels, out_channels=None, ks=5):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
-
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
-
- def forward(self,x):
- return self.up(x)
-
-
-class Downsample(nn.Module):
- """
- A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- stride = 2 if dims != 3 else (1, 2, 2)
- if use_conv:
- self.op = conv_nd(
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
- )
- else:
- assert self.channels == self.out_channels
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- return self.op(x)
-
-
-class ResBlock(TimestepBlock):
- """
- A residual block that can optionally change the number of channels.
- :param channels: the number of input channels.
- :param emb_channels: the number of timestep embedding channels.
- :param dropout: the rate of dropout.
- :param out_channels: if specified, the number of out channels.
- :param use_conv: if True and out_channels is specified, use a spatial
- convolution instead of a smaller 1x1 convolution to change the
- channels in the skip connection.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param use_checkpoint: if True, use gradient checkpointing on this module.
- :param up: if True, use this block for upsampling.
- :param down: if True, use this block for downsampling.
- """
-
- def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
- ):
- super().__init__()
- self.channels = channels
- self.emb_channels = emb_channels
- self.dropout = dropout
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_checkpoint = use_checkpoint
- self.use_scale_shift_norm = use_scale_shift_norm
-
- self.in_layers = nn.Sequential(
- normalization(channels),
- nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
- )
-
- self.updown = up or down
-
- if up:
- self.h_upd = Upsample(channels, False, dims)
- self.x_upd = Upsample(channels, False, dims)
- elif down:
- self.h_upd = Downsample(channels, False, dims)
- self.x_upd = Downsample(channels, False, dims)
- else:
- self.h_upd = self.x_upd = nn.Identity()
-
- self.emb_layers = nn.Sequential(
- nn.SiLU(),
- linear(
- emb_channels,
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
- ),
- )
- self.out_layers = nn.Sequential(
- normalization(self.out_channels),
- nn.SiLU(),
- nn.Dropout(p=dropout),
- zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
- ),
- )
-
- if self.out_channels == channels:
- self.skip_connection = nn.Identity()
- elif use_conv:
- self.skip_connection = conv_nd(
- dims, channels, self.out_channels, 3, padding=1
- )
- else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
- def forward(self, x, emb):
- """
- Apply the block to a Tensor, conditioned on a timestep embedding.
- :param x: an [N x C x ...] Tensor of features.
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
- :return: an [N x C x ...] Tensor of outputs.
- """
- if self.use_checkpoint:
- return checkpoint(self._forward, x, emb)
- else:
- return self._forward(x, emb)
-
-
- def _forward(self, x, emb):
- if self.updown:
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
- h = in_rest(x)
- h = self.h_upd(h)
- x = self.x_upd(x)
- h = in_conv(h)
- else:
- h = self.in_layers(x)
- emb_out = self.emb_layers(emb).type(h.dtype)
- while len(emb_out.shape) < len(h.shape):
- emb_out = emb_out[..., None]
- if self.use_scale_shift_norm:
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
- scale, shift = th.chunk(emb_out, 2, dim=1)
- h = out_norm(h) * (1 + scale) + shift
- h = out_rest(h)
- else:
- h = h + emb_out
- h = self.out_layers(h)
- return self.skip_connection(x) + h
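-
-    # With use_scale_shift_norm=True the timestep embedding is injected
-    # FiLM-style: emb_out is split into (scale, shift) and applied as
-    # norm(h) * (1 + scale) + shift instead of being added to h directly.
-    # Rough shape sketch (illustrative): for h of shape [N, C, H, W], emb_out
-    # has shape [N, 2*C, 1, 1] and each chunk broadcasts over H and W.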
-
-
-class AttentionBlock(nn.Module):
- """
- An attention block that allows spatial positions to attend to each other.
- Originally ported from here, but adapted to the N-d case.
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
- """
-
- def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
- ):
- super().__init__()
- self.channels = channels
- if num_head_channels == -1:
- self.num_heads = num_heads
- else:
- assert (
- channels % num_head_channels == 0
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
- self.num_heads = channels // num_head_channels
- self.use_checkpoint = use_checkpoint
- self.norm = normalization(channels)
- self.qkv = conv_nd(1, channels, channels * 3, 1)
- if use_new_attention_order:
- # split qkv before split heads
- self.attention = QKVAttention(self.num_heads)
- else:
- # split heads before split qkv
- self.attention = QKVAttentionLegacy(self.num_heads)
-
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
- def forward(self, x):
- if self.use_checkpoint:
- return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
- #return pt_checkpoint(self._forward, x) # pytorch
- else:
- return self._forward(x)
-
- def _forward(self, x):
- b, c, *spatial = x.shape
- x = x.reshape(b, c, -1)
- qkv = self.qkv(self.norm(x))
- h = self.attention(qkv)
- h = self.proj_out(h)
- return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
- """
- A counter for the `thop` package to count the operations in an
- attention operation.
- Meant to be used like:
- macs, params = thop.profile(
- model,
- inputs=(inputs, timestamps),
- custom_ops={QKVAttention: QKVAttention.count_flops},
- )
- """
- b, c, *spatial = y[0].shape
- num_spatial = int(np.prod(spatial))
- # We perform two matmuls with the same number of ops.
- # The first computes the weight matrix, the second computes
- # the combination of the value vectors.
- matmul_ops = 2 * b * (num_spatial ** 2) * c
- model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
- """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts", q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v)
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
- """
- A module which performs QKV attention and splits in a different order.
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.chunk(3, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts",
- (q * scale).view(bs * self.n_heads, ch, length),
- (k * scale).view(bs * self.n_heads, ch, length),
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
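-
-# Illustrative difference between the two attention modules: QKVAttentionLegacy
-# first reshapes to [N*H, 3*C, T] and then splits q, k, v along the channel
-# axis, while QKVAttention first chunks the [N, 3*H*C, T] tensor into q, k, v
-# and only then folds the heads into the batch dimension. Both scale q and k
-# by 1/sqrt(sqrt(C)) before the einsum, so the product carries the usual
-# 1/sqrt(C) softmax scaling while staying numerically stable in float16.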
-
-
-class UNetModel(nn.Module):
- """
- The full UNet model with attention and timestep embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
-    :param num_head_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
- :param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
- from_pretrained: str=None
- ):
- super().__init__()
- if use_spatial_transformer:
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
-
- if context_dim is not None:
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
- from omegaconf.listconfig import ListConfig
- if type(context_dim) == ListConfig:
- context_dim = list(context_dim)
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- if num_heads == -1:
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
-
- if num_head_channels == -1:
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
-
- self.image_size = image_size
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.num_classes = num_classes
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
- self.predict_codebook_ids = n_embed is not None
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- if self.num_classes is not None:
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
-
- self.output_blocks = nn.ModuleList([])
- for level, mult in list(enumerate(channel_mult))[::-1]:
- for i in range(num_res_blocks + 1):
- ich = input_block_chans.pop()
- layers = [
- ResBlock(
- ch + ich,
- time_embed_dim,
- dropout,
- out_channels=model_channels * mult,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = model_channels * mult
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads_upsample,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- )
- )
- if level and i == num_res_blocks:
- out_ch = ch
- layers.append(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- up=True,
- )
- if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
- )
- ds //= 2
- self.output_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
-
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
- )
- if self.predict_codebook_ids:
- self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
- # if use_fp16:
- # self.convert_to_fp16()
- from diffusers.modeling_utils import load_state_dict
- if from_pretrained is not None:
- state_dict = load_state_dict(from_pretrained)
- self._load_pretrained_model(state_dict)
-
- def _input_blocks_mapping(self, input_dict):
- res_dict = {}
- for key_, value_ in input_dict.items():
- id_0 = int(key_[13])
- if "resnets" in key_:
- id_1 = int(key_[23])
- target_id = 3 * id_0 + 1 + id_1
- post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\
- .replace('norm1', 'in_layers.0')\
- .replace('norm2', 'out_layers.0')\
- .replace('conv1', 'in_layers.2')\
- .replace('conv2', 'out_layers.3')\
- .replace('conv_shortcut', 'skip_connection')
- res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_
- elif "attentions" in key_:
- id_1 = int(key_[26])
- target_id = 3 * id_0 + 1 + id_1
- post_fix = key_[28:]
- res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_
- elif "downsamplers" in key_:
- post_fix = key_[35:]
- target_id = 3 * (id_0 + 1)
- res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_
- return res_dict
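-
-    # Illustrative example of the remapping (assuming diffusers-style keys):
-    # a key such as "input_blocks.0.resnets.1.conv1.weight" (already renamed
-    # from "down_blocks.0.resnets.1.conv1.weight" in _state_key_mapping) has
-    # id_0 = 0 and id_1 = 1, so target_id = 3*0 + 1 + 1 = 4 and it becomes
-    # "input_blocks.4.0.in_layers.2.weight" in this UNet's state dict.
-    # Note: the single-character indexing (key_[13], key_[23], ...) assumes
-    # fewer than 10 blocks per group, which holds for this UNet configuration.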
-
-
- def _mid_blocks_mapping(self, mid_dict):
- res_dict = {}
- for key_, value_ in mid_dict.items():
- if "resnets" in key_:
- temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \
- .replace('norm1', 'in_layers.0') \
- .replace('norm2', 'out_layers.0') \
- .replace('conv1', 'in_layers.2') \
- .replace('conv2', 'out_layers.3') \
- .replace('conv_shortcut', 'skip_connection')\
- .replace('middle_block.resnets.0', 'middle_block.0')\
- .replace('middle_block.resnets.1', 'middle_block.2')
- res_dict[temp_key_] = value_
- elif "attentions" in key_:
- res_dict[key_.replace('attentions.0', '1')] = value_
- return res_dict
-
- def _other_blocks_mapping(self, other_dict):
- res_dict = {}
- for key_, value_ in other_dict.items():
- tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\
- .replace('time_embedding.linear_1', 'time_embed.0')\
- .replace('time_embedding.linear_2', 'time_embed.2')\
- .replace('conv_norm_out', 'out.0')\
- .replace('conv_out', 'out.2')
- res_dict[tmp_key] = value_
- return res_dict
-
-
- def _output_blocks_mapping(self, output_dict):
- res_dict = {}
- for key_, value_ in output_dict.items():
- id_0 = int(key_[14])
- if "resnets" in key_:
- id_1 = int(key_[24])
- target_id = 3 * id_0 + id_1
- post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \
- .replace('norm1', 'in_layers.0') \
- .replace('norm2', 'out_layers.0') \
- .replace('conv1', 'in_layers.2') \
- .replace('conv2', 'out_layers.3') \
- .replace('conv_shortcut', 'skip_connection')
- res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_
- elif "attentions" in key_:
- id_1 = int(key_[27])
- target_id = 3 * id_0 + id_1
- post_fix = key_[29:]
- res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_
- elif "upsamplers" in key_:
- post_fix = key_[34:]
- target_id = 3 * (id_0 + 1) - 1
- mid_str = '.2.conv.' if target_id != 2 else '.1.conv.'
- res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_
- return res_dict
-
- def _state_key_mapping(self, state_dict: dict):
- import re
- res_dict = {}
- input_dict = {}
- mid_dict = {}
- output_dict = {}
- other_dict = {}
- for key_, value_ in state_dict.items():
- if "down_blocks" in key_:
- input_dict[key_.replace('down_blocks', 'input_blocks')] = value_
- elif "up_blocks" in key_:
- output_dict[key_.replace('up_blocks', 'output_blocks')] = value_
- elif "mid_block" in key_:
- mid_dict[key_.replace('mid_block', 'middle_block')] = value_
- else:
- other_dict[key_] = value_
-
- input_dict = self._input_blocks_mapping(input_dict)
- output_dict = self._output_blocks_mapping(output_dict)
- mid_dict = self._mid_blocks_mapping(mid_dict)
- other_dict = self._other_blocks_mapping(other_dict)
- # key_list = state_dict.keys()
- # key_str = " ".join(key_list)
-
- # for key_, val_ in state_dict.items():
- # key_ = key_.replace("down_blocks", "input_blocks")\
- # .replace("up_blocks", 'output_blocks')
- # res_dict[key_] = val_
- res_dict.update(input_dict)
- res_dict.update(output_dict)
- res_dict.update(mid_dict)
- res_dict.update(other_dict)
-
- return res_dict
-
- def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
- state_dict = self._state_key_mapping(state_dict)
- model_state_dict = self.state_dict()
- loaded_keys = [k for k in state_dict.keys()]
- expected_keys = list(model_state_dict.keys())
- original_loaded_keys = loaded_keys
- missing_keys = list(set(expected_keys) - set(loaded_keys))
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
- if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
- state_dict,
- model_state_dict,
- original_loaded_keys,
- ignore_mismatched_sizes,
- )
- error_msgs = self._load_state_dict_into_model(state_dict)
- return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
- def _load_state_dict_into_model(self, state_dict):
- # Convert old format to new format if needed from a PyTorch state_dict
- # copy state_dict so _load_from_state_dict can modify it
- state_dict = state_dict.copy()
- error_msgs = []
-
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
- # so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix=""):
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
- module._load_from_state_dict(*args)
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(self)
-
- return error_msgs
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
- self.output_blocks.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
- self.output_blocks.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via crossattn
- :param y: an [N] Tensor of labels, if class-conditional.
- :return: an [N x C x ...] Tensor of outputs.
- """
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- if self.num_classes is not None:
- assert y.shape == (x.shape[0],)
- emb = emb + self.label_emb(y)
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context)
- hs.append(h)
- h = self.middle_block(h, emb, context)
- for module in self.output_blocks:
- h = th.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context)
- h = h.type(self.dtype)
- if self.predict_codebook_ids:
- return self.id_predictor(h)
- else:
- return self.out(h)
-
-
-class EncoderUNetModel(nn.Module):
- """
- The half UNet model with attention and timestep embedding.
- For usage, see UNet.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- pool="adaptive",
- *args,
- **kwargs
- ):
- super().__init__()
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
- self.pool = pool
- if pool == "adaptive":
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d((1, 1)),
- zero_module(conv_nd(dims, ch, out_channels, 1)),
- nn.Flatten(),
- )
- elif pool == "attention":
- assert num_head_channels != -1
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- AttentionPool2d(
- (image_size // ds), ch, num_head_channels, out_channels
- ),
- )
- elif pool == "spatial":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- nn.ReLU(),
- nn.Linear(2048, self.out_channels),
- )
- elif pool == "spatial_v2":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- normalization(2048),
- nn.SiLU(),
- nn.Linear(2048, self.out_channels),
- )
- else:
- raise NotImplementedError(f"Unexpected {pool} pooling")
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :return: an [N x K] Tensor of outputs.
- """
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
- results = []
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = self.middle_block(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = th.cat(results, axis=-1)
- return self.out(h)
- else:
- h = h.type(self.dtype)
- return self.out(h)
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py
deleted file mode 100644
index a7db9369c..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
-
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-
-from ldm.util import instantiate_from_config
-
-
-def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if schedule == "linear":
- betas = (
- torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
- )
-
- elif schedule == "cosine":
- timesteps = (
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
- )
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
- alphas = torch.cos(alphas).pow(2)
- alphas = alphas / alphas[0]
- betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
-
- elif schedule == "sqrt_linear":
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
- elif schedule == "sqrt":
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
- else:
- raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
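-
-# Usage sketch (illustrative, with the defaults above):
-#   betas = make_beta_schedule("linear", 1000)
-# returns a numpy array of 1000 betas. Note that the "linear" branch is linear
-# in sqrt(beta) space (linear_start/linear_end are square-rooted before the
-# linspace and the result is squared), not in beta itself.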
-
-
-def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
- if ddim_discr_method == 'uniform':
- c = num_ddpm_timesteps // num_ddim_timesteps
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
- elif ddim_discr_method == 'quad':
- ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
- else:
- raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
-
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
- if verbose:
- print(f'Selected timesteps for ddim sampler: {steps_out}')
- return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
- # select alphas for computing the variance schedule
- alphas = alphacums[ddim_timesteps]
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
-    # according to the formula provided in https://arxiv.org/abs/2010.02502
- sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
- if verbose:
- print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
- print(f'For the chosen value of eta, which is {eta}, '
- f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
- return sigmas, alphas, alphas_prev
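-
-# The sigma expression above implements the formula from the DDIM paper cited
-# in the comment:
-#   sigma_t = eta * sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
-# with a_t the cumulative alpha products selected at the DDIM timesteps.
-# eta = 0 gives deterministic DDIM sampling; eta = 1 gives DDPM-like stochasticity.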
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
- :param num_diffusion_timesteps: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
- """
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas)
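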
-
-
-def extract_into_tensor(a, t, x_shape):
- b, *_ = t.shape
- out = a.gather(-1, t)
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
-def checkpoint(func, inputs, params, flag):
- """
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
- """
- if flag:
- args = tuple(inputs) + tuple(params)
- return CheckpointFunction.apply(func, len(inputs), *args)
- else:
- return func(*inputs)
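-
-# Usage sketch (illustrative, hypothetical names): wrapping a block's forward
-# so its activations are recomputed in the backward pass instead of stored:
-#
-#   def _forward(x, emb):
-#       return block(x, emb)
-#   out = checkpoint(_forward, (x, emb), block.parameters(), flag=True)
-#
-# Passing flag=False falls back to a plain call, so the same call site works
-# with checkpointing disabled.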
-
-
-class CheckpointFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, run_function, length, *args):
- ctx.run_function = run_function
- ctx.input_tensors = list(args[:length])
- ctx.input_params = list(args[length:])
-
- with torch.no_grad():
- output_tensors = ctx.run_function(*ctx.input_tensors)
- return output_tensors
-
- @staticmethod
- def backward(ctx, *output_grads):
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
- with torch.enable_grad():
- # Fixes a bug where the first op in run_function modifies the
- # Tensor storage in place, which is not allowed for detach()'d
- # Tensors.
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
- output_tensors = ctx.run_function(*shallow_copies)
- input_grads = torch.autograd.grad(
- output_tensors,
- ctx.input_tensors + ctx.input_params,
- output_grads,
- allow_unused=True,
- )
- del ctx.input_tensors
- del ctx.input_params
- del output_tensors
- return (None, None) + input_grads
-
-
-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
- """
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
- """
- if not repeat_only:
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
- ).to(device=timesteps.device)
- args = timesteps[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- else:
- embedding = repeat(timesteps, 'b -> b d', d=dim)
- if use_fp16:
- return embedding.half()
- else:
- return embedding
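-
-# Illustrative summary: for a timestep t and output dimension dim, the embedding
-# concatenates cos(t * f_i) and sin(t * f_i) over half = dim // 2 frequencies
-# f_i = max_period ** (-i / half), i = 0..half-1, zero-padding one extra column
-# when dim is odd. With repeat_only=True the raw timestep is simply broadcast to
-# [N, dim]. E.g. timestep_embedding(torch.arange(4), 128) returns a [4, 128]
-# tensor (cast to fp16 here because use_fp16 defaults to True).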
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def scale_module(module, scale):
- """
- Scale the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().mul_(scale)
- return module
-
-
-def mean_flat(tensor):
- """
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels, precision=16):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- if precision == 16:
- return GroupNorm16(16, channels)
- else:
- return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
-
-class GroupNorm16(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.half()).type(x.dtype)
-
-class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
-
-def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return nn.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return nn.Linear(*args, **kwargs)
-
-
-def avg_pool_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D average pooling module.
- """
- if dims == 1:
- return nn.AvgPool1d(*args, **kwargs)
- elif dims == 2:
- return nn.AvgPool2d(*args, **kwargs)
- elif dims == 3:
- return nn.AvgPool3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
-
- def __init__(self, c_concat_config, c_crossattn_config):
- super().__init__()
- self.concat_conditioner = instantiate_from_config(c_concat_config)
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
- def forward(self, c_concat, c_crossattn):
- c_concat = self.concat_conditioner(c_concat)
- c_crossattn = self.crossattn_conditioner(c_crossattn)
- return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
- noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py
deleted file mode 100644
index f2b8ef901..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import torch
-import numpy as np
-
-
-class AbstractDistribution:
- def sample(self):
- raise NotImplementedError()
-
- def mode(self):
- raise NotImplementedError()
-
-
-class DiracDistribution(AbstractDistribution):
- def __init__(self, value):
- self.value = value
-
- def sample(self):
- return self.value
-
- def mode(self):
- return self.value
-
-
-class DiagonalGaussianDistribution(object):
- def __init__(self, parameters, deterministic=False):
- self.parameters = parameters
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
- self.deterministic = deterministic
- self.std = torch.exp(0.5 * self.logvar)
- self.var = torch.exp(self.logvar)
- if self.deterministic:
- self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
-
- def sample(self):
- x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
- return x
-
- def kl(self, other=None):
- if self.deterministic:
- return torch.Tensor([0.])
- else:
- if other is None:
- return 0.5 * torch.sum(torch.pow(self.mean, 2)
- + self.var - 1.0 - self.logvar,
- dim=[1, 2, 3])
- else:
- return 0.5 * torch.sum(
- torch.pow(self.mean - other.mean, 2) / other.var
- + self.var / other.var - 1.0 - self.logvar + other.logvar,
- dim=[1, 2, 3])
-
- def nll(self, sample, dims=[1,2,3]):
- if self.deterministic:
- return torch.Tensor([0.])
- logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
- dim=dims)
-
- def mode(self):
- return self.mean
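-
-    # kl() above is the closed-form KL divergence between diagonal Gaussians,
-    # summed over the non-batch dims. Against the standard normal it reduces to
-    #   0.5 * sum(mu^2 + var - 1 - logvar)
-    # and nll() is the per-sample negative log-likelihood
-    #   0.5 * sum(log(2*pi) + logvar + (x - mu)^2 / var).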
-
-
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
- Compute the KL divergence between two gaussians.
- Shapes are automatically broadcasted, so batches can be compared to
- scalars, among other use cases.
- """
- tensor = None
- for obj in (mean1, logvar1, mean2, logvar2):
- if isinstance(obj, torch.Tensor):
- tensor = obj
- break
- assert tensor is not None, "at least one argument must be a Tensor"
-
- # Force variances to be Tensors. Broadcasting helps convert scalars to
- # Tensors, but it does not work for torch.exp().
- logvar1, logvar2 = [
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
- for x in (logvar1, logvar2)
- ]
-
- return 0.5 * (
- -1.0
- + logvar2
- - logvar1
- + torch.exp(logvar1 - logvar2)
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
- )
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/ema.py b/examples/tutorial/stable_diffusion/ldm/modules/ema.py
deleted file mode 100644
index c8c75af43..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/ema.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import torch
-from torch import nn
-
-
-class LitEma(nn.Module):
- def __init__(self, model, decay=0.9999, use_num_upates=True):
- super().__init__()
- if decay < 0.0 or decay > 1.0:
- raise ValueError('Decay must be between 0 and 1')
-
- self.m_name2s_name = {}
- self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
- self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
- else torch.tensor(-1,dtype=torch.int))
-
- for name, p in model.named_parameters():
- if p.requires_grad:
-                # remove '.' since the character is not allowed in buffer names
- s_name = name.replace('.','')
- self.m_name2s_name.update({name:s_name})
- self.register_buffer(s_name,p.clone().detach().data)
-
- self.collected_params = []
-
- def forward(self,model):
- decay = self.decay
-
- if self.num_updates >= 0:
- self.num_updates += 1
- decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
-
- one_minus_decay = 1.0 - decay
-
- with torch.no_grad():
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
-
- for key in m_param:
- if m_param[key].requires_grad:
- sname = self.m_name2s_name[key]
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
- shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
- else:
- assert not key in self.m_name2s_name
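-
-    # The in-place update above is the standard EMA rule
-    #   shadow <- shadow - (1 - decay) * (shadow - param)
-    #           = decay * shadow + (1 - decay) * param,
-    # with decay warmed up as min(self.decay, (1 + n) / (10 + n)) over the
-    # first updates when use_num_upates is enabled.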
-
- def copy_to(self, model):
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
- for key in m_param:
- if m_param[key].requires_grad:
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
- else:
- assert not key in self.m_name2s_name
-
- def store(self, parameters):
- """
- Save the current parameters for restoring later.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored.
- """
- self.collected_params = [param.clone() for param in parameters]
-
- def restore(self, parameters):
- """
- Restore the parameters stored with the `store` method.
- Useful to validate the model with EMA parameters without affecting the
- original optimization process. Store the parameters before the
- `copy_to` method. After validation (or model saving), use this to
- restore the former parameters.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters.
- """
- for c_param, param in zip(self.collected_params, parameters):
- param.data.copy_(c_param.data)
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py
deleted file mode 100644
index 8cfc01e5d..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py
+++ /dev/null
@@ -1,264 +0,0 @@
-import types
-
-import torch
-import torch.nn as nn
-from functools import partial
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
-import kornia
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
-
-from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
-
-
-class AbstractEncoder(nn.Module):
- def __init__(self):
- super().__init__()
-
- def encode(self, *args, **kwargs):
- raise NotImplementedError
-
-
-
-class ClassEmbedder(nn.Module):
- def __init__(self, embed_dim, n_classes=1000, key='class'):
- super().__init__()
- self.key = key
- self.embedding = nn.Embedding(n_classes, embed_dim)
-
- def forward(self, batch, key=None):
- if key is None:
- key = self.key
- # this is for use in crossattn
- c = batch[key][:, None]
- c = self.embedding(c)
- return c
-
-
-class TransformerEmbedder(AbstractEncoder):
- """Some transformer encoder layers"""
- def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
- super().__init__()
- self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer))
-
- def forward(self, tokens):
- tokens = tokens.to(self.device) # meh
- z = self.transformer(tokens, return_embeddings=True)
- return z
-
- def encode(self, x):
- return self(x)
-
-
-class BERTTokenizer(AbstractEncoder):
- """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
- def __init__(self, device="cuda", vq_interface=True, max_length=77):
- super().__init__()
-        from transformers import BertTokenizerFast  # TODO: add to requirements
- self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
- self.device = device
- self.vq_interface = vq_interface
- self.max_length = max_length
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- tokens = batch_encoding["input_ids"].to(self.device)
- return tokens
-
- @torch.no_grad()
- def encode(self, text):
- tokens = self(text)
- if not self.vq_interface:
- return tokens
- return None, None, [None, None, tokens]
-
- def decode(self, text):
- return text
-
-
-class BERTEmbedder(AbstractEncoder):
- """Uses the BERT tokenizr model and add some transformer encoder layers"""
- def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
- super().__init__()
- self.use_tknz_fn = use_tokenizer
- if self.use_tknz_fn:
- self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
- self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer),
- emb_dropout=embedding_dropout)
-
- def forward(self, text):
- if self.use_tknz_fn:
- tokens = self.tknz_fn(text)#.to(self.device)
- else:
- tokens = text
- z = self.transformer(tokens, return_embeddings=True)
- return z
-
- def encode(self, text):
- # output of length 77
- return self(text)
-
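-
-# A minimal usage sketch (illustrative only, not part of the original file): the width, depth and
-# prompt below are arbitrary example values, and a CUDA device plus the bert-base-uncased
-# checkpoint are assumed to be available.
-def _bert_embedder_example():
-    embedder = BERTEmbedder(n_embed=640, n_layer=12).cuda()
-    z = embedder.encode(["a painting of a cat"])
-    # z has shape (batch, max_seq_len, n_embed) == (1, 77, 640)
-    return z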
-
-class SpatialRescaler(nn.Module):
- def __init__(self,
- n_stages=1,
- method='bilinear',
- multiplier=0.5,
- in_channels=3,
- out_channels=None,
- bias=False):
- super().__init__()
- self.n_stages = n_stages
- assert self.n_stages >= 0
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
- self.multiplier = multiplier
- self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
- self.remap_output = out_channels is not None
- if self.remap_output:
- print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
-
- def forward(self,x):
- for stage in range(self.n_stages):
- x = self.interpolator(x, scale_factor=self.multiplier)
-
-
- if self.remap_output:
- x = self.channel_mapper(x)
- return x
-
- def encode(self, x):
- return self(x)
-
-
-class CLIPTextModelZero(CLIPTextModel):
- config_class = CLIPTextConfig
-
- def __init__(self, config: CLIPTextConfig):
- super().__init__(config)
- self.text_model = CLIPTextTransformerZero(config)
-
-class CLIPTextTransformerZero(CLIPTextTransformer):
- def _build_causal_attention_mask(self, bsz, seq_len):
- # lazily create causal attention mask, with full attention between the vision tokens
- # pytorch uses additive attention mask; fill with -inf
- mask = torch.empty(bsz, seq_len, seq_len)
- mask.fill_(float("-inf"))
-        mask.triu_(1)  # zero out the diagonal and the lower triangle
- mask = mask.unsqueeze(1) # expand mask
- return mask.half()
-
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from Hugging Face)"""
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
- super().__init__()
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
-
- if use_fp16:
- self.transformer = CLIPTextModelZero.from_pretrained(version)
- else:
- self.transformer = CLIPTextModel.from_pretrained(version)
-
- # print(self.transformer.modules())
- # print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
- self.device = device
- self.max_length = max_length
- self.freeze()
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- # tokens = batch_encoding["input_ids"].to(self.device)
- tokens = batch_encoding["input_ids"].to(self.device)
- # print("token type: {}".format(tokens.dtype))
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- return z
-
- def encode(self, text):
- return self(text)
-
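-
-# A minimal usage sketch (illustrative only, not part of the original file): the prompt is an
-# arbitrary example, a CUDA device is assumed, and the (1, 77, 768) shape assumes the default
-# openai/clip-vit-large-patch14 checkpoint with max_length=77.
-def _frozen_clip_embedder_example():
-    embedder = FrozenCLIPEmbedder(device="cuda", use_fp16=False).to("cuda")
-    z = embedder.encode(["a photograph of an astronaut riding a horse"])
-    # z is the text transformer's last hidden state: shape (1, max_length, hidden_size) == (1, 77, 768)
-    return z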
-
-class FrozenCLIPTextEmbedder(nn.Module):
- """
-    Uses the OpenAI CLIP model (clip.load) to encode text.
- """
- def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
- super().__init__()
- self.model, _ = clip.load(version, jit=False, device="cpu")
- self.device = device
- self.max_length = max_length
- self.n_repeat = n_repeat
- self.normalize = normalize
-
- def freeze(self):
- self.model = self.model.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- tokens = clip.tokenize(text).to(self.device)
- z = self.model.encode_text(tokens)
- if self.normalize:
- z = z / torch.linalg.norm(z, dim=1, keepdim=True)
- return z
-
- def encode(self, text):
- z = self(text)
- if z.ndim==2:
- z = z[:, None, :]
- z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
- return z
-
-
-class FrozenClipImageEmbedder(nn.Module):
- """
- Uses the CLIP image encoder.
- """
- def __init__(
- self,
- model,
- jit=False,
- device='cuda' if torch.cuda.is_available() else 'cpu',
- antialias=False,
- ):
- super().__init__()
- self.model, _ = clip.load(name=model, device=device, jit=jit)
-
- self.antialias = antialias
-
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
-
- def preprocess(self, x):
- # normalize to [0,1]
- x = kornia.geometry.resize(x, (224, 224),
- interpolation='bicubic',align_corners=True,
- antialias=self.antialias)
- x = (x + 1.) / 2.
- # renormalize according to clip
- x = kornia.enhance.normalize(x, self.mean, self.std)
- return x
-
- def forward(self, x):
- # x is assumed to be in range [-1,1]
- return self.model.encode_image(self.preprocess(x))
-
-
-if __name__ == "__main__":
- from ldm.util import count_params
- model = FrozenCLIPEmbedder()
- count_params(model, verbose=True)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py
deleted file mode 100644
index 2a7a73879..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""
-Fused Attention
-===============
-This is a Triton implementation of the Flash Attention algorithm
-(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
-"""
-
-import torch
-try:
- from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
-except ImportError:
- raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
-
-
-def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
- """
- Arguments:
- qkv: (batch*seq, 3, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (total, nheads, headdim).
- """
- max_s = seq_len
- cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
- device=qkv.device)
- out = flash_attn_unpadded_qkvpacked_func(
- qkv, cu_seqlens, max_s, 0.0,
- softmax_scale=sm_scale, causal=False
- )
- return out
-
-
-def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
- """
- Arguments:
- q: (batch*seq, nheads, headdim)
- kv: (batch*seq, 2, nheads, headdim)
- batch_size: int.
-        q_seqlen: int.
-        kv_seqlen: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
- out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
- return out
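-
-
-# A minimal usage sketch (illustrative only, not part of the original file): the shapes are
-# arbitrary example values, and fp16 CUDA tensors are assumed since the flash-attn kernels
-# require half precision on GPU.
-def _flash_attention_qkv_example():
-    batch_size, seq_len, nheads, headdim = 2, 128, 8, 64
-    qkv = torch.randn(batch_size * seq_len, 3, nheads, headdim, dtype=torch.float16, device="cuda")
-    out = flash_attention_qkv(qkv, sm_scale=headdim ** -0.5, batch_size=batch_size, seq_len=seq_len)
-    # out: (batch_size * seq_len, nheads, headdim), one attention output per packed token
-    return out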
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py
deleted file mode 100644
index 7836cada8..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
-from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
deleted file mode 100644
index 32ef56169..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
+++ /dev/null
@@ -1,730 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
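-
-# A minimal usage sketch (illustrative only, not part of the original file; the angle and the
-# eigenvalue scales are arbitrary example values):
-#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
-#   # k.shape == (15, 15), k.sum() ~= 1.0; with l1 == l2 it degenerates to an isotropic Gaussian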
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
-    # Calculate Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
-    h[h < np.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
-        weight (float): Sharp weight. Default: 0.5.
-        radius (float): Kernel size of Gaussian blur. Default: 50.
-        threshold (int): Threshold on abs(I - B) * 255 for the sharpening mask. Default: 10.
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(30, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
-    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
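-
-# A minimal usage sketch (illustrative only, not part of the original file): the input size and
-# patch size are arbitrary example values; any HxWxC float image in [0, 1] that is at least
-# lq_patchsize * sf on each side would do.
-def _degradation_bsrgan_example():
-    hq_image = np.random.rand(512, 512, 3).astype(np.float32)
-    lq, hq = degradation_bsrgan(hq_image, sf=4, lq_patchsize=72)
-    # lq: (72, 72, 3) degraded low-quality patch; hq: (288, 288, 3) aligned high-quality patch
-    return lq, hq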
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
-    example: dict with a single key "image" holding the degraded low-quality image (uint8, HxWxC)
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- elif i == 1:
- image = add_blur(image, sf=sf)
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
-
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- example = {"image":image}
- return example
-
-
-# TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
-def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
- """
- This is an extended degradation model by combining
- the degradation models of BSRGAN and Real-ESRGAN
- ----------
-    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
-    sf: scale factor
-    shuffle_prob: probability of globally shuffling the degradation order
-    use_sharp: whether to sharpen the hq image with USM before degrading
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- if use_sharp:
- img = add_sharpening(img)
- hq = img.copy()
-
- if random.random() < shuffle_prob:
- shuffle_order = random.sample(range(13), 13)
- else:
- shuffle_order = list(range(13))
- # local shuffle for noise, JPEG is always the last one
- shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
- shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
-
- poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
-
- for i in shuffle_order:
- if i == 0:
- img = add_blur(img, sf=sf)
- elif i == 1:
- img = add_resize(img, sf=sf)
- elif i == 2:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 3:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 4:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 5:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- elif i == 6:
- img = add_JPEG_noise(img)
- elif i == 7:
- img = add_blur(img, sf=sf)
- elif i == 8:
- img = add_resize(img, sf=sf)
- elif i == 9:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 10:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 11:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 12:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- else:
- print('check the shuffle!')
-
- # resize to desired size
- img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
- interpolation=random.choice([1, 2, 3]))
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf, lq_patchsize)
-
- return img, hq
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- print(img)
- img = util.uint2single(img)
- print(img)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
-        img_hq = img
-        img_lq = deg_fn(img)["image"]
-        img_lq = util.uint2single(img_lq)
-        print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
-
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
deleted file mode 100644
index 9e1f82399..000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
+++ /dev/null
@@ -1,650 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
-    # Calculate Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
-    h[h < np.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
-        weight (float): Sharp weight. Default: 0.5.
-        radius (float): Kernel size of Gaussian blur. Default: 50.
-        threshold (int): Threshold on abs(I - B) * 255 for the sharpening mask. Default: 10.
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
-
- wd2 = wd2/4
- wd = wd/4
-
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(80, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
-    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
-    example: dict with a single key "image" holding the degraded low-quality image (uint8, HxWxC)
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- # elif i == 1:
- # image = add_blur(image, sf=sf)
-
- if i == 0:
- pass
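-        # no-op branch: kept so the elif chain below still lines up now that the i == 1 case is commented out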
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.8:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
-
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
- #
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- example = {"image": image}
- return example
-
-
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_hq = img
- img_lq = deg_fn(img)["image"]
- img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
- (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png
deleted file mode 100644
index 4249b43de0f22707758d13c240268a401642f6e6..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
zUwbW&%<1za<>*8(B_#&u$$`j?3(&h_-Qp4c`VARE;jIEb!_QaPYckEbJkm|(vE7EL1mpFU(()@41
zMWq_W<(6{<=!q=4Opg8+BpLA=#c3+~weIhP=RE`u
zdKQ)=XA$k-eG6Ly%teq%Nf0q}
zY2gCqzs10a2rZ>~Qj*Wbze<>|=8>m%os)=e8hoc*kv`Wk*HQAwaD@gv8=<1-&Tk-At7
zxzv7AFv|Iyx8uSD=-+*gVmNOb64!R{P86>YR6tb98O951r~l5Bl@3{cxv-ijDsvoSP%T)a
z{Infv<@O)F@n%Ya%zKt+jN3K;6@Q*P_#~n0nIuip4{Q6=&!Zw42Y+*D%RV6xp8BdP
z;LnGG)`P9ZzfmzU;ikwsElw-MnbGpJfM|_u7?b+i*z_G#2p(
zzktob@edHGGG%AqiM#3JQX{YgM3nP>8rBtXxt
z?@*nqieEyp+Pnb>e8iN^?#5Ny{o_SVF!mTIwEd
zVNG%<%O;m|ad{juP6c^3a!965e_vEn
zbCVs6jiRCL%47pLR-JA#IYjx{%)}52L}gptcqGhN;odbn$KqLe|_5Y)~JmT
z3Z?c!ul69z9lN};nob@u9P6&`n~f*1mlX<*s?RH$js{oJMn+!z`bcLQbaV2!`g9#4
z!fgQgY>+&%%?ba9BDt#-PrLV`AVI7ZoOdPIGxW&dBPC=u<1aD8QTZ~r^~7lUpD_lwElgI3#V7i^hoR5u6SPRfiLqH
zehPbPug-hO*6L>9dGC&;`{5Bg`zg$Fxl`hh+tf}-y|2^qf_F!wMkru>%C{day=HDM
zWs1%4V1r!+V(%L_)!ihWm`*Inb|Vd);<=vpNjTjki!l;>Qj
z!YTfj6tDd}HH_J68;9wA5fA%!s}l4BJb{w(Z4Rhs*qObmd&@Y
z|Cy!6YTYh6pp7d$hDtT6Y7}$N@w|5fWCKGbB%&k=ee~deG(QSJ`m=IBQMGxGU;6K|
zgk*o)((WXy#4fJN&v5TfB7JgetE0Hw$_)P*x8PGl!cj7}t6%
zh$9MCI$Fv&UiDA8|LJfzN-0@RShj0MgV9JZvc=!zCe%
z#0a~=6&lPvg*D{hwjSku+wTI7iVK39j()vn$*GBz-wj0h`_xpVd)^EjVAE=RclI}4
zop`ylcb_(~yZAR)>)eQ%$otdWDdTw{F+JG%7rzQ-%z$a}J@Lhz>V!lIO-=V>+{L!6
zlIfBFy{}7+b@z2#_Wx+a{@d?naz;q<#~51eR!G`Z#L=^+q`8s6{dGF|?oG&Dh1p;S
zPFbGe?6TbQ`PRnla!%buonn;Ev!t6LxoD{#y-R9=~+SA3Qc{QQa*G-77iYYU^X+}T!-GA`%ItURE`+*4{T-PPqimDr45Cnr)|iO!aNaiB#`lQp
z>T{aU)5Hl2S_?08U-Bd?>nvBEtsUwC##!KIFVHQ!Gte^(
zK|aWl_TH8KHep~SeL}#SSE~FT4E*aF1!P6EB_<&gfSu%2SMlEeBATmwdbZzD8>r9K
zc3k5NZcv(Aofyuo&QlPy(dSyMPqd&A>jop7i|O@Wwcd^|M_
z(165SSlgm_^du{v>z!$z&V~73=Wd(ICkWWem^Kisdn-2fTAcfh)3yXn2ztDNx4|ZE
zQ)fo(=DrPQ;YkPy?_Z|B5XW7=F4eMYSIz=l;KvXy_eA5%Jv|^W(o~Q-)KBt6KYJRU
zM{ZDLsVXHF1l=q*EiY*DW}Jl1s?OfZMbGjOpnA^BIu=1l&kwb@5KiWUyX15psGq3R
zstpOk+i(gbR#wM}or)NVHPuy1s@v-0?8#<61L4;K0Z-NX)%we7?zg%)R(bbQi7d52
zPJXdsLXDprNF32_ZEa;wR4FMb4Js)CQt&N3njNPUwz9D?X4ju>yT3Xj)VYrAv6~y`
z@LM$5=I`z`!x$L@
z7`t~R5v`nJ{Zz+PJ#!c8cqpvl)|}^k-C!tRcCUF_v;d&=BD)|fj5fXzQ&ofhI9uSd
z^uFx=D?PFM{|%3>C_7;-0qbT{cXc0{bxp-DPb5pNVYkH(D`hw;3E|bYp*!5c$~@m%
z&Dj1O<}+L<1wG0U<)RR~(KJ^u8nIEX!z=ti^>4?bBC$TvJxR7uZw1dtg}~%`woO_#
zQ?~YlwUUe$Bbt+i|D)Ppy0jmV@%BHD=Tq#H5%4WKBWrw_zAFlPUXB#YX#p|i?l{Lu<
zA#!*MYR+c!_uq1))NtDr+8~KUfBC~HzUy<#N*rX2Xwr9IS^P%rRrwO+`5@
zMN*a|*WzuSh?JIZN#WW1Kcs
ztD|6(JM&30<=dL=sc4jWhRTlkYcm5VSeU?L^&0y$aDP9gNNI3zd9T)&z3cGllY|V{
zuRjZiP8cE{e#!o;t(4Qp8X2)gzQ{Hgjk)4xiGj`OM6|ZJWGxC5j)=ZKrjlbLv2ed>
zipj1J#qI6wHP?vAyN5EPO$JUwF}I(pq~%(YZDan}cYlLoP3K(O|NKyRq$|{tNFv`o
z95YKReOzJAuoGUjOmtH`GEgz@VD_La$oVNpkuqBk_BnjDs>*L-*%22~SWcdwZ{68*
zc{X_3U#MZag*l?Ox6f|nWRVqYvutPQLg=tLgTa_QXCF`aC-~-o)fMFD$X6Ca4JjE
zWzVUKtD0SeHfM@4iy|
zaZ}SkVNdCUPTZI#-p=h4$JK{O|Bf9^*%;92TkQ
zmH8U1)hpczHoA%)B0=M*7EeBbQ^nc$Ff7Ub
z=_k|~0fhNo+QcBo)LY(Yxh}T-N_YPUbAN@gx0Vrm<0;zA$2_jYDs?R48BrXj!
zmB|MI8?Tp?TqYfXYmyo-UX;%?oC_CR^Jj9ao_VEg^`gLv+&5Ceev4B!n*ZfF*O9eJ
z$%y>7>g8d;#s6!S=XSC274B)~c{q|BZrNE)Uvg#&KDAB9>7_(>s9U3SYgOxiLKSW=
zVc-R4u(#U%4u37M8BijRcsfo@u&X#*P~{#smJ>)JLvZuVV%WCJy(@tSVn_U{9w0@~8blJ*eIC6}lPb9h-4y?Zr_@wrlZBKx
zWajF%oZ0N4ikg_cotS24dUG}>&Xk{SWZNk753>HP{p`-Hd!B7WoN`pWBvUG?sy#L_
zF%jZqAYh6SykXW*#SWp7k>u=N?cuCMpK{Hvg)-TCNo2aAO<)4<;Y$XFP`T63eFT6u
zrC_iQj?Csd2k2XB&~2~MOSR`PLd%61GX+nDj5ocGK2@AaQsvT-pBWSp%Oq%8aLNXz
zV>9y^(Q>=a#u#xDw`Pey5&Qy2srvt!=U)sGb_-_IQZ{zhc5^s^=*Wm_^3-O?E8I(q
zAWK`LndTKwl1|i4J^i{~ky&_z4)pO7%m{?!m=g|>Om2zyw+)tc;N!yo^0^iMC}&um
zhC8&iKlNFyJou|@ka;%a+t?$5^jmqNu<+lv-5{GnP0Pz|#MABy=7*d!$C6|0nV@o@`HxGH<6{~nk-
z-$`N|K6t>ZGb$Ue`@_|C`FYIw2nC1wcc6OJncAuSzsnnqtGw$?oZtF->~3A`Mhc_<
zN>;E04o}5om8St>_B~lA=EKdtxz}Xz$L3~d
zwe_Tdl23HyUC>jV^_PQ`7&|DPxiLh6w#TKc1E~bj(G+R)Exl=H;nS)9YH68$)^D5c
zw^wUPJQsCGv|?V8YNx(vsn);$t_LK1S#Mu6QN1E!TT(#y0$hB2d?qJQz8!(|l=}L}
z9t*elqWPN7GuXsS2JrwN{F>-yH20H=tXe~yI^a3yA+ETp1RzV
z=H=c0I;qFW!ak+a^sf!ag)u!0=T`Mch@2Asq4(lOhAVt_cKfHDWwh5Td%Dd`P7aI3
z+73i31-Y3eetQOS^Or>ma(r{X|Q>1-(Y;1iOMsEtoNGB#obi`aRQbvybt}{)vrPE)vV)Hm
zKe+-Dz;kYj$sv#)xAM#Hra|q#?e1QLRX8wldF31fK!s|~(#B=kgIbs=gGe#I{}<3H
zE5J1$&N637X4-S(=o>?3Nc5oX-I|q&<^LjsQm#4nJZ`G=E)gv!V8Lg{xDp+N`J3&RmR8vzD;@<(
z$1VAxA!#K-^LUe9^y~U8GaZXTs_;djNIz&J^yzuAfIolsGgKm$>vp5p?>BKeuK5)$
z95EUbfo=D@D~q*E98r6inKxA%LaQ4#`U0PsX>3A(5^=bi3+g{_JUit7dVu@5rQDOw
zhE;a8jF!H1S(Ch;yTf@75y~cO7h%D$V1_zWG7QHTS7Hb$>&*fTtxpt-1$btgG02n=evMl6&G(Q2ZiT
z4fIfPTb6yH@i*kPQT4AM4&46LVnKYoX`&0o7j-6iuz??jMGF&Tul5N*x|GX)x1GFv
z!x=iXqkO4Y+bqoup)B{6C-s@I9@pUX)KWbqdYThDA8>Y$H>>uyQbuMKQ~JjVU=T?k
zS2}E!7=OM}N2Kv+(w|HL`-@LUID1B%r1i_4&~?Or5yp5O-sI>)(cDyzs$*OPbpBaA
zu9Pn`fn{!@ZYp!)z4`#~x8tsubSb($K!eBsoQ#XHaNgWqQ&kz_i3Mx>Q^OTL$3VvN
zCMnx9`G3X=2z2C3HAE;M`OVLv8A
zL25qjnM*Qr3vK`Em7HjawM5F@xA&wvN2Oged)PTonQ~}-e6Mb0Glpq;TY;QC;7ipc
z^(?$S-`+p=sr-K&opn@`|NF*AH*A0i(j$j}G>j5qgtU~TG)gx}hs5X*$$@~*Y&z8P}}^mBM(6!^$FMq-Ti^YIk9?i+vD)I
zrB|05(mG^NHw>=E=MO>z4aF&4hf1o>e2NZqvFo;9`&0V{>Tp46C7e)e42f@0aFSX<
zDRsIU)J7YWsz(Yb{LNbul|lhAp>DvB`r!Tj@-WLXR4bi}3y)a$0Vwbo&{J0~<+$7c
znYQ1LiOWbYJZUU=_AJL+8&Ft*Us8+=8aSlQ26e5S`$&IC&uPd3T*C_sHDk0-7J~q}
zDYs1TYoojMzj$@HmcBDOMOe!|ce`lQuWbkR1j`Bi#Z-u@9LGZ8EkRWwYyOD9&``Lg
zVCdVN!ue7q4Ook&ClmywIW_PSWEU1{;t(n(7={;LE&;FD)j|4CDXvQfzH3dZkI3H1
zL}meo?mK^suXmLzRqsfTfp13*+DK@aYs{VDl=u~+>eeg0MijNOc6wzbyXj9v|EHvz
zyCce{_qXqJFs3G)J7OP8QQrF>vM0;7?hXNiE%Aiq*WNJ)E9>|B4zWuA%%ZXflCyVT
zne-pjViA{z_`m})PR@w}bhhwI%vmIL21y*IY6ZeV&nQ9KQPue9HRt&KGeZIv}6$$&)}4FW#S&GISW+
z=a-~Fzk!BGGA%99h9hueR6yPdR|&m8eRO?JJX{%>%yjT@gk&>mS#cDN!_&@%Pw{UM
zWpGG~<6GynVY%Wy1(MBI~2g*9N
zve2uDAX9hM%BfQxEZ`@rt10X07K9?fQk6d()fE_!;>L4DN<(!Oe}znF)+Mc(Ssvpf
zvYDWwGao?DIG#i&=Wc=p1?A(n*{S2`B<0C5C+gjhmB_c``D%U322{_Td^m-ovXNAL
zXK5IpH<>Fv`9=TjJ8gHgyh|1}*Ve)A(cXRxWcBMp`_ENf&sl?|s68TkiPzbhMZI3^Jn?kl)@}
zswidvZ+!;P>S|4;k(sEB#1owvAUoLlyXk@IuI}ZJAfD&9QYa9AJn9~9nn?l#kgcEH&zVjh?|`H9p27&*b&K*4=76h!ywvucOM8
zwU60!$rd66f?~ruFmR9x;7mt1e(euQTsrjYS`o+nfs^g{iVoymdlLvG0|{O-_YudH
zpG&mn!o8)R9BkVc=mAl(keV3-M7r7QpJk)(pYb-`8PmdD%2(W%fE(`EE-?_sGR_=W
z0i-xzhzJm9{#m^kThny&>M@ONycQihO%f@AG>a}ZE_*B`*Hmw6dOYz{!g^gZjl=>K
zBsl23az@V3^tyF=hKAqebS#c0mVd0nUyLX23;v6lRaJDG+&Vt9Is(wPT7F$NHLa?W
zTTjzhI9e?zslvFv$szxK!5?!2o&5`^0fn0tMkwGP(Ot-Qv)S*xa8G{y7eW?E9NM2F
zBZS8x%cMykPJiMV9&>tW_L4<}f=EgH1Mg22RX2JmsTLa5SC6TQH;|FmM@YXD$Dbf8
zw
zJRwnGb|xkApODgIP*jl#j)(INB_(1Ezn}IX8t;qs4duez%^SJ?%u^&=o)YIqtbH$N
z3`PH*(~4ETcX7fxqjC6{%R>#CB@!mJfZg+g%hhF^B=+HvVHOjA)A4g#m0P4C=P=^V
zzC8L+*<0pMRp-0&CtaG}_i^^G=$^+>jI=7aaKBrWe%L1N$Fj{erI181RU)u*En!3uvZx_=`517fkA8Wu(i1UXUw5#Kc+d*{xx4vzMZB
zDh~ZpTZZBy@<6s@#cw@gti5{wE;J=c`cxXHa9~VqQ0n6(Y>R%vYXU&_EM0^Qp?Lfc
z&@?tuV=SuKj^A$X?)=)G?EKH|281?jazbc%Z+kwivQI01-`uo?
zELAHiz%fREE;+P|6=^ZSUkxa>Cwsb(c63Yg7}xVk48RLY2mDkezgA20)|_0^78Ek#gr0MQ4z*%2
zs~{n+XA0gLoZaETT+F^vGeEge(2t*7?(Y&)h@en&)yr6u+r~
z0^2hA68%&{tgj!b)p2pYEk2=a-t5ZW15ewUkiX%b6Y5sx#`YOMC=e=+4Wc8q+2UbS
zKrlqd#gk9>P(FQe;<8fv8|!u5H~IALzKk^!MfJTfEixh{T>SJ@XBP+yYMX}>73{I7
zKAic~*~(gBS@#8S8{tm~w&NY3sXZrP0~wBQ!YL~NI|bF~pdBKaxEnUUJ~g=OHmGE=
z65Bxit|-s!C5Qk`_xp+-pJaU5yLWz{{<6B?U}C2?5hDWE;#mX{3$<0zul
z!Sj`W*+|$kZ`s&rlIF|oKr5!^AH+vy_H}c4Fx*^sDJG>-4AES?@x(8?WsO_J0h8FCUGo1<`
zK4&-dGfe4n{HQ;Dulx6K~dhb$zHJ(Ed
zjErQe3-d#}`N##|yW1t;mdANo({+E5^6zg7`*iXHAwT@Jf@0qJE77(KNiFpGYn9
z%Kc+giry>VVCj^OZ?m`
zK7BcGrf8dvK~YtLo9!1sOV|#u{+VH)%dLO2m1Sx2cdL)8^pV}~ru)R~(uyzhX8Smb
z#0hB{{ZDDAA!PraTq^w}A9|*(?Xj4?UPnO>3-$`fccW#0;*he#E#?lP+)sv#pMZvc
z4xFC){#7gd(|1fvxE@|t2>}VshQC$Y$5Ft6Yo4797n8k|%N>xOu`N}^6}#oGQn*}v
zc)K!`^)c-BNbCW5)r`k$qRWl6iGhA{g|{c}>qO&wL+T<#WPBoxto<=8-c5K{TttKl
zD&C)?G!2^WLfalYjSxf#|J+E^D=0yw5p9j>na4i@)iY|&WH81tWfWen#2ASw
zNq9)ji^JL2g>a~|`Tl?yx?^l`W^jdyP3RNg5_$b^iPi}>1Y=#@n}RH=<|F32gPF9R
zEe8#q<8miY@xog6
z|F*A4xQXSwiOF0RDW*i5b$bq*ARONDh%73bfRM?TEJ;C2LR>?n4*NWuyLtfG&z}EJI@Vm
z8NO7OW&oi=sTimT^e~9APaU>i-Zue&O|o9U{JXW#b-VQ>Y_;)lZ|~2UkI^|WImVhE
z2g_%P4A_x?Nunw+ejTg5F5uWb$vyR70?Kp#*rmft=?^JSo^u+|_X~>(C;ZaWE~8T#JocVWSIm)Z
zc@D`$W~65Qg9ZyP7x*qm+~X*oU{*C
zHYYg1s`Of2p#iV8XJYMhxL>xf9e>JAh&*fpU_Pt46Eg;X4&u=lu2sJ7N7YXJQ6SjR
zN`^8bwi3o}t@4ONx>%`{jyPQgN;q8ZVEbn38&38l_M7i5;J#g=dse9DbxI`OiA63L~qG9!vp
zdVSU}BUGP#_GHEUM9zv*+}R=9SYIgFvDb>K{?awGp+zcHBoC({iPZ2Rs7IIs`b89p
zIO#_Z<1ocknxh@1ZU!X1O`$P6t18rhhfP(fSoQ-T|KFbMaS5}P=g|~KUrs;|N61kq
zxmk(`nXo)XVv^muATeV_MyE8E2e#^(4&n5pB?Ifh(ymLd%%V!$^4Q{~%RTLQyh0|Wt|Lvxn)I4w`@ZhBOS7P!k!AoUU
zP3CM7r9bPtc}S6tgWx{ia7x+BMJgQL`|QKtB~{QWEIV5s*VrchaQb@+8BW9Jfx*ju
z5#n>wH#jJ>`P1~wh;iiYg~gS!qm)?~F>YESBdkpv`JSQ5}@iRVlz
z<-&uza&KylK>BdZY*QrZ*$EYzz3V$V1A?esU_FfzV!*PxWKXAMX
zkiuDs;p_5)5qRUH6&Z>M*Rxi4SJvn1>h;&sx$LC8UxWic6K{)XkwNEv%wy)!%BdiB
zQVs2v4C>c!XnnUA6Zlp7`?sxZ5#WsEB9LbLnCO$TRWs-D6;9>G?*l!@mJ9T&V5@?%
zfZTLWhd9lDLi6OzZq|G7dBzL*3)e|53&AWDknA#9I0uBLy^cInn0+n}ck@uV#70COC>k@;c%GnE3byXf3J}X;M#_+9+
zJy22WCkD*!(zE|1P2aq!3}K=vilp+O_%c_R;x+}D>Rx%y%tihdlCYrw?*lx-aV3|Y
zLVl+V-y(1*6+^p2(hM2i&)BNnG&WCzx|2sQ6yBu}vxrH`+;VsHNb*$z`Go^qm8BoWZzxc9=;FVscykpm!q2ZDo%K6WoQhKN-9
z+B_=7qD>wGL`*aI2w}4(0glS#5+bougxYyP6rb}?s20@7XL76dC|HX-V;bdwE79@g
zRQxRO?D7EJfWbUHAml8BGndR}oZdnLZ!d0F-a+vZ-p++g7nRGDTJ+Q?sm
zaj7*o$8l{QKxzcNJjY&%d|=Y_ON`SO_)ia5K1bjQGQPA@exN;I(tr`g`#zGNX3@CX$`u?
zB&SqZIy(!cuMW@3n0Zx|Q<@D9N;Xgu}6JTIL)sGxk&WhT39bH>kJ^!dBn
zHp}2f1%Cub=tdz)HaT(0AlDv~$gG)Pt7ek;oZ5K1MoatBZg>@A2pAxqt$bM^9PXoq
zOWAU&=sJwG=&H0Fxi8#>EM3C3;9T6)6GyU|ao*7Gy7xj*vnUPRT$w-v3i02>UKs)F
z#4?_uAjOd}wQ>qjDr&EgYX$eAzErp>6#p_d5dxjL@N~2(<;IUe`j8JVCJDXmyb@_M8-wqCMkfZAs!yyn&nRG<=fj*vzQjm8EPMcZUjzE
z^qv$Dqc3*Ceu=uE3MJv}8+T2l9Cj-2yX?pbd^4x$Dr+iAq{t8OP8mgT*v=jbKgTx&
zpE9Lz+2I!!k;aX<6aWqo07shT8Ae{qO0Y7o}qvI%ouX*|rW|Ahi~uK@2IO~mr=&ch|(
zrx86`FGQnYPsgba*9p*L-soJO2OL!(kOSJ^*qU#v9hJ(aVY8w4Rpbf6!0V`ENap%>
z3wRmgT|ThNgi1(06}fPqvrAhSYv`%)g&Y=3~)YHa^M0OztQ##
zJw-hPGJ*#29Z`JP8G3cQ71$B4Ca4_Sc~oOdj=$LGY68$`ArU#tAxjrGtw~B>drC6?
zx!%)DJ3TdUpzPDg3B5lp)5&_x**+JtVkAo&^FmvZE|i!C4S{POIcIJN}@68g1y`oQDM;IwiOEe@fV$MZk8
z|Fih6Y3mAkNc!+dN-kZRJ+Jtc=sN2&@>%)s_M?WHQ5Kr>)L%(Wpn4(
ztENrUD-pi^6NSQrO%6wxMj%GnX`bEijvbu(ES%=32;a}25tQ5^qT$J+My+TB@@56+
zSn#jWUhw}Sl?DJak{l*wt149;hqh~j^z4H_SG8i*nZPePIuDiNUc}`DrHGI7K>@QQ
zLiXBf+qZ)wlCLtrwPU_OUt2R=Z7fYyv7ZwB0oJL}9kX%aidKetC?tSXZ`tk>rYUV#
zEdK`*ry8TR#%7Ij`GAql$IfGh&l=i-K3jl5Pc#vy9og`mTjL>LvT0Ii!NhCOUx2J6
z#%w?bQMqa#@XCd|NVC80)&urvjRGx7&WE9vae6tNye9z#VC!4}bsL>t(HIhz^J=@|
zOUyWMt6p_mKmo`DAxTlr%Ah&nZn=JuqTrlSgeI=y1Isla%1#A8I1qiB>6+_AI1Z=N
zAzX6^x2nYHuGdX|4)x_eLW_5)&5ClIpPlGZz8NvCf$`0!+x#2jFEK?Nv{ue&
z`Z1&QtuMb&zPqii?6MHy=OR4M;W!G~Bw&t*H5p#=A4yIDpxly#exADUr7N)9ux!F)
z{5kE5HFjh10r>471+%c{em9f7P=h@_qUIlJwIz+
zoX}AKx8c>c#x5*s^5$oXL0REhr?ux=V@WZ_7gv-aphBVitUnvTSkPY{n@J5?8P4zSNWKX5
z?FTTjze*Pvg&w~aszsSg#Rmr?`pbVy&;Hc(^OqD;LfDAC#G}}VXHy}~vU7;_z4Udq
zYz#d#N+Qa;rZ4^M;MON#x0tx7BC1a$;!B=6&7WoP^^aGPzT^M<>yoT7YgjS7I?A=7
z(1H?8N6AjZvXl2McuY$<(Y*idrBuaGx+wHnXD8@Ol6lv&cJ{iz#924%C55in#Y;6m
z3%8Xs5`(T0))|+Q)P-$jBR8F1aCY@|(Zf0qV-x9Ox^Wl)b!mV=9NhY0JyEDp^}O0C
ztL*i2>cp7b^HSA2@~Lm(&EcizE4%`uux~eQ0eE`cM2f8IY;MbKO%~I3_`stYvna>?SvUDA%--)p^$!iSU~;G2n}|e*
z_D{sLYIh7|^%3{{-;iG~IyyQ^GJvan&VaN72+5}E(bd@{(~ZS?^UkgaG&3|bTPG*R
z*eVm#Lo{cYQXOE*>1^q01+T>5;t2qc2>p9HgwjW%
zP1f%YUEhoXer|HmX{ZJO^)yL0uL06iZ53KGU-;w7;<6ETxd7z(Q%lvm7Bh2s5mI^y
z-jA!fGC~7-kJZV?h~^
zmIyLn-j;nJ=Fj=aLZb+~C89M0K#?1P4Dl99U2yE5W&Qns&od>S(?l7ZuZ)dl8Ed1q
zMxTg2uBvZsYmMH+VX$+c7c{{KM}&PP=p|qiV#DR&pAq1o9n(Db(f?p_<@!2qTv9aX
zq2ZR|_$?|*ZDfoF!g9p2v0YOsf6cFLV1umo{)IG&q>`6ntHgYnHxR?83KxzUuU$Fz
zV<$kgn+x`mD_|saciTE=zd6xln#ONfS!hlN3EAbNBB={Gd{%R^uCOy2f-UoYTPcjH
z93`JYSh0W|8+B5vzgMNKdYWU0!JSdNkf~RX+P*}U%sF&a!PqEXG;s&8Q}N#--!JTQzeZ+)~#wTxnprZ`G3SFAG0KJ5zhlk4$?@1+@D-=k<~(V`gdhS(p?8!YzMoSoHXgZDq~y^}|IS|!
zr!bX>4J7=A+!g&>795weZ5dl(U;4^Y?yhv=KMs0+g(F42yY0T=Og86_4WO}oW`Jl@&O%J;*cQ>h7wq^$kr+|VyUf|YjK^~Pne^SF(+r$u(M#BL`z
zvEsjg^wpcTHW_DBmgHK~?>%}v1*B)!nkA2rLS4~#kfk$PJQmzqt?I$gwKM&Ah#s(F
z_qa>m)vmb5;6P%m@xI2e0aHem*NM;DkdS~tlsC`@5Eu}GNhll7$?={*TBXHUEMWA~
zgm&7EB~3oVte&0;bIYir{AC-Ess7;xEzhgwjdoh3b|4nfgve=CF#XVr2a%Vs(imgs
z@fL84XZx(4=DO1eY(@;Dr$h`Z9YoLDgjJ<$R0zbd6|c73jjtXEY{LP9a!+nU^}Y=`
z$k?f2;B!EHT+ZU)Y>9T%3!#|WuN@5mMNP6(#
z1|SE$AfMJeaaMju>cQ2_$15oj);s#PTFY+ThD^N=IIH=W+uGm`#HJ0~38h2@$pUbAec
z$7WiYKS2A}qzlhn9J^|a;`Rw`z8eaxG`W7Di~6d<3u;(1KAT*VWt+ZM7GD!lok)Dq
z*}~quE|FKX|NfKxZ$(gDT6~5X2f;(RdV}iKXu)VBWsP}iHmUw_B>pZFJE%%ZA$I!}
z1t>lWe?4<9OWHIBa;#tyR~V=6Qx_wx{`f-mnK%{IgS1lOiP*vP7SaWW&Pixe&j77W
z?MeKS^#a^dc)5Ko8T&S8(zakwHlen>(8_*c%JAEsZ}9lxhF=q7G0o>}X=o|~Qi16a
znJwIP9=G16#q03NynTtVm_k=*J&U~+!*rm4<>0zWOG1K6_ch}?Qh^WO1Y1hjeu{K|
zf4b01P&i>i%L27oIL{kbdFkyzqhIy=Dwt(xI;d;KMN!?Ho+OH3I1!cW-9P5*hNLxL
z*j{If=ggcBAAy&4kMpXtkP=zBnVRMSB_*2K7fV3~y4Hx={vP-w{NW4X;c==yU3Com
zV9?}PY4-{_BU`(sC0>qONO~KLAP@RPPp^%^>2=?Ll{H!2;8l7+MI#~%#n`Fjr|6Kb3Jra)fYC78vYlThPqe8`
z1Q-gmByJjbapQwMCvL#o0fY*_zoB09Bh)6^i~v0ENqO=TDd^Q|E3N#U4iIiVi-DWUXldjt6X
zZUTe9LJ$aRxFwM5YlvuySd7|W>*hmiihr5F#UImOZVMH~_mZF4A
zf>_$U`y2p&LfOp7XO((Mix7742AHJ9d52h=QfcRH{LmF_S9(T}J
zcN+^?8_IrFV9C-I%rKNTT$!8Usm%>A&ih5u!
znTE_DkRo2t!h2_es4;p|x@SrG@nQ27VKWU&3~F|?JYz@UN;rkDfIff(#wM#lN@VQvrKFGEe~HuldsA1rlX8e5f)?70JtEY+VOWvlkf{
zQSl}J_s7g9N6F$jMbyN$A}7daik6mye&3`T3!(TY|53!cl+B^+@fxt=GW%yu-UEW?8Wt`LUm~B@*
z?!hC4n=M4dd)aOqIjPVtEsuzt{`QJ0zS|NpQFzk+&D@io&@F+sa{p%5m+z5&StTYnDq=)NKqz_h^lf`f#~c@{LNi0%
zcaAqO69Ror77nEC^nAHE6+Lp<=00LI=9U(dA*&(4g?Hl6cHH{P7%N-h>R%*P-t9;!QHGpcgBCTFCycV=ER!xt8u9+rAk!D5Pl0Qzcxaf_|P9U+KVTHAJ{
z1XDQ{8HMwXD&E-Z0iABQOCxStw3+j!RKeuK2hTVS#SdK*1xnt^Ck=`mUvol%s+uth
zh_@ip*ja`}haG=sxR}DZqUXw*-uUn7sI8!ha)*DPgBtAcvdwq)&Hqm3pd-p_WJc`V
zqG`qL`1t5z=}va1?-Yeyb`gOlvR~YUin=6@TG>|T*OV9_)M1ZEW&(b=N#3j^n`C^M
z%iS?`0vbOy-&|AFI90nDJ7W%PtCrCi^LTGT#Bn}rOhJyBE8jO?$2Ml0c&@BLa<6EqCEO?=npCZ=&AkrvD5}*o3zW)Q
zhq+47O*S&H;PtjTqGkSHue*^SD?goX{n>m~Sqv^T`>?#+Q;gWCOWs6doSFddF}Q5O
z(`D~J&kD-X5Nd%UaQ$j@gcs7XiF-7aa6c>apK3#tai?qdx;lB!`RhcjpGcETIg0M$
zbv@s~GnI_NR}9%BM69w^AgS|Y5HQpkIB4XlsP_KnZRDlCPA&CNVeTE9z$;CoN<+F=
z+?4?l>+yX8+w7ksX+QVc=T7PiE=H6=6G~*?v02%VXnDC(c1J9`-ZV+JQ601R-5idO
zj{}`2JJQD^L`ILiL*4JdL8$FM*}U=y
zW-dD&-Q
z4e~=g`le#RW92sVgk6Dub2(^17USe-1}b**d?}YMd*_A~x7TIa0qQyDvsZ85P5?*h
z^6tptDY+bI_J@=61UyBfdQ)r?F?$}e;M*sZt)G$Bb8zN4VKF!=mLxoQb0aw;)><;A
zOZ@7A>6|I4KLlh$?qDu6zB!7ub^eNGew7ltfG2&DtfvWcResC#r0`q70O|qWiKX9ygr!`q}JNww{-ocTURC=9Y-|%or4HcpQQh-qA$DfY0clYF39O$M%hG2u;2(*$p_x
z$!K9u=b+tM@3`!VN1PNWZ+lW(8%i^!z$bfcybaakh6NaPAQ1zB;HuaCH$vx4L#Y?U`C6(6o^lduu|H?7a*;5?cJY2g3wpcw2hU4H=ODK}hsV
zWl8E5x}2@ZjNd1#lo?c$Y}oh*ffF+j1U4}EJS*bdrYZHRUil0E1#v>PRe&2-cHzhB
zL2K;Yy?-r?B8~{cAxd{d~?&b
zsViw^FxqFrn*-q+&a0rWq|yyBw%T!=X+!?-B_XNu5U=5b)L{zvOTF8mJwAvo=>pS*BZAWa@gX+!IakXVcbG99#mXi%
z@b%Z?OQzRlgb>Sv!aYXeU7ek?Ml}%Ejx;kt~lNP3-6=c3sca7|i)iS2_u{4%V*crdc(umC$Oq
z`CW9dB$tg6#5FFtYRY-!m68=zwRoVDz6TApsN1rOD175(zYw91nELf?_0xH~M9}o3
zXZ0&?HRO~*+=B;Q>hB(ws=#{3XQx(!Y+u)^I~y8T_lJ-P3kNC__o#o$A6PXTj*P6l
z#Ce;;Toe0z;T-0RHK2_Bp9+XjcVz%&Uu|uj2g~y9%L0%2lal#$Icmy~<7J~~ib!Ej
z(3@h5HCM?H;^&4>HnY9A=k*dTvOp1_N-P1aiB1tjkRV4=MCB>;0gy(WMCIeG`FbEU
z(yB@yZ4yBq^7&2`O_EJLG~W3<)^2#}a*8UO6h3PQDYu-mU^-onNMHj10uG%r$%`
z258%=8Lu;13vw)9y%O96TwHF!b17@f%Wjf+w4W;5+uQjmVwH2)b5CRk!ykXoWr9qJ
zCDp{f#7`7X=ZNj^P0D*cG?wMq3g8Gw?F&SqrSx%AZyJE<`}l@_vy{~dT@(Ax!a$x7
z%DJPC{>DdbFI*wIQV`zYgWNvNyhL~{PW+|8&i!bD0lsneQDb2$AO9l
zhURaPjS26!@}LVC5-4xZK=ZSNc%#y+Pr4BvFWPz8tku&}73SCjcDmuLC=MR>c~8{n
ztSN_ryDMS@Ow5Ff(;AL+D+#w;@Qau5gyNd-=n+7+b2VTkLIpa(@;bb7ym*kD?5t-_
z1Z)qGyO)xEHODt$fAWCn!~WVqOhIHDD&?akrDcKT#LhI{%8JWcSC|^?+~Q%}a%$+m
ztge92kO1j+7E6{`v(>d_anCaI9=N?Su17T=^JBv_YIBFxz+I@7E~4_=BT!ZSBk@!p
z-_OP}q=vS4m1v%>Lp_g;*y;vJ5I>>*KD9ws%t-BW^bc>Yn%>_1s|%Ja$V%q}8*=&Z
z-~7^9&yAaRGSab>AfFFO@qF-yk?v^b6ji+H?SNGm34|SbN`#1yh&5f~KVlI77}R{)
zi*d2HzZv!h_Q5%VE0@w6)+^#7QCg7x17U1P!XCBmethIH{$6uGRsavFW-!dg@<;v+
zRS2;seWU)!jBHsohw4l=#NweIakU)>{!QdAQ#9D6TyD9Udp2_T^1+5QA
zfiV=)eB$*x-XxOx(pqO&w259kUkAhZ-JVX^R}Ao^-o#1@mtgn>f~SC)72FH3duL|e
zcl>?n&~;8LTslrTNTOY)GyxxUYg;i+VX#GJjJ?X<5P
zjjab;^Bc>?!yg2(UJ6GQ@`>-r?rfeKJ99;~wcUUft3DXAO(tm-4PY|$s)Rl!51|@(
z>a(63FvHh^AR9k&`PgTFXzyqU1_;ZM3`WdY(;pqLxipzoCz<8_{?BRRXo6naVhv(b
zfl==W#D(uPpV~7ScADNKAmPvn@5a!lgY=3_5@v=0A#%Veq<=qtnv8;qxe){G2><{f
zsBGZc_=*mmtX=`~rH|=k)q5J1;V0R|UJB@zjpItTJIfAjEgc==)w<5(GRN(bZBGpI
zy)RbR4lXR#XkNJ5GYyF*M7FL&h9Lmh;``0_w6?^}4UadN{3oxS`OKW30{8}d+X%}m
z+s9WPB_GhvRA$qU)Bf{dW#^0dDjkpWN+5=|2ksP|breV-(FOl?@Wu4n+qr676Ff#u
z3icE*O;~^HS*2K?TRSFQUe3w3A5lR{O4brKLf^Nw*x-V=u|OJpA({MO(j9ah2kJ)O
zH%L?hyha%=qE17UXM}_!NrD5Rb;66fGe()kB&mk`%*xtD4*`|Li$U%)b}0qNWl}tm
zlh#riIy&^+&3gXQ`HKHq$4%baYS`sPHCbol6}D{Q>FwXs8SJzCt}yJ;#f4iJt6pMW
zCsvrZ`$~k>(sEn&y;6SJ=rdh7<*g%BJEkrhYN
zb?`u0WxYFMBF_7!E`b?rMr_;V*8S;rT|NDudEdHyY40QUUQ}7xlaFNqzx6&U1_uT^
zE$bmK;%CyE-jx^}w^NDj?46(VCN;HLkWYJPhz{a`uv#ZQ(d$6-Y9{@=OPnvleRFS~prKD1p4U$wk`4d_N@YNaYbhx%OJ1$(dtw`Wc@{gf2
z;=?f+^G;{-QV(rvC8Nrt!2ES38GKOTXuuw4v;-ua$~^1O=|LHKZJi11**Rb~5LPeePpm34zw|ujDP9*SP+4Tocs2$EB#p}yKBqzPhK1=U#d3&F@EXSg{Bk;
z_@BQZ0NJQt6h@t0YzRQXE%d!tUOA=kw`)`#44HHlkFDZLb$5)S^U6J(OU9rs1#~fn
zgb!1ZX8C_yE{{WYTYsV2P^w{uZ*oN6L%41_C8uik36DE|?{>(!j{!*S$<3{w?I{&_
z3Pb?zA(Ojz#^26!K4(zRapBC!L=FHBJqo|7nqYmc-<40sEn=UDCLa}?XrSO!j
zv}g@M`?&P&aR;@!DoipUvjlp3D@Ex~Y>MGo#h;GfSrDI&_r2qgW}z&0+Iu&V=DmW&
zerjQ$xY1hRdSK;%Q1HrqsH%Z&>7?uOWP(_nISzjNoVXcHoF;4VT$s2iee~+B>_==nrkAKWe9>Sn4etHnz>bW#Wmh)46kK
zz)aC?_`Q{5w4I9W?)^+}Q&u^VCO&WR+te2N<8a2WDFOEV+|`buDtbn20zL%x%M*Zf
z2E6@yvY|vOyc67lg4BA-pUn#8ox9}UX{xwf`>hXCuUsC>~$9fcxuNxE9t%8`UXy_c#@wis2WX;CQ>^OW<
z_;e<~n%8=WK&SWdOE8_$Oue#+1W(n*e~|xPzMa;t+mCm_5#LbHi#l)F=$+tEd~kbx
zh{@wACQME8-()K6PNysb^?y0A>c=5%sEuso<}-J;f3x^#K4z7MEFCxJTmo0Bs#st_
zkCaU%e$;8G`4^wUF6aYhcG(myLMrW5z>vYH&KPr26?+48qPwqlwP^H^V6hu#?)UdY
z|0bW_>JEhbyK@gczh5~F&0{JwP*jbO_AU7prz1Fc7y54@>@;s@CVS`4GQMe!j%st;
z4bQ({A3K?zg#A5z$VQX|B0wT4aIKW`&8)wFo+ADGg@oT%8qdnL{=W;Oz03_djg>TC
zwTH^Fe5B2!Xj+3=xGC7Ic5!zWe~;eY64?KGP8Dn~jb^R(hm
z)mJWGBjIHqL!dm7QJXYI*{WUs}oT
zxa5@`I>=1e!df&c_P>P%y6g|4)+e8ORM562!}edUn{sr*=$(~ZH9R!*
z=%(O5Or1(JsqydpsjabRD#2ZaE)KovzPK-Y8m6}8<-f9~_^jwOe}1KaTS@Ry$lv$$D-GPEBX-mkjzp
ziq1Qp>i>`8myjgxwMoX6zS$|6H(O-8_O(Kk9T%6(WZcZi%te$vQo8mC*<8uqWL%NN
zm7D#0|L&hXdPw))&wHHLInTq^=ghI=7y92=RC=8+XJhks9ex&@XN6Aqz!1x!cZVWb
zJ&*jH6>6%Ftk%T+`Kea&E-2GJ@9oq!yiROkJo{F-Xtw13#(y64SGJcr|?;AKdIwRq3U^WH=1ibv8nheb1f
z4Owc-<>;^TKA~4;x6yvyJ49N=l~yLlYIp;hH~wjlP&x_yA9M1aKjwpPA{46ve1UX
zsOR0KXSdm2x|U}QOb1Ey&y`(%#PayEwRA&LOO`3e$bnma>g`;KjyI|owFWEr@U`6)
z_)B%j+cFfUE~4)*1G3NH)GbXd
zvz{1fQKkawVv2}ZX;3HtTobaOPe$CQrJJ7$ttzRugDf}Cb8~~!@d*nWbQZOR)z7+1
zCnY5Ta0k%8#v7LBo506FmK$c9drcID*MWQZwkNK8^l-Je3o2Inl}qB?Ud)old%Ol@
z2`3XbJ@jpHZeig^LP;v}tj>Tmd4Uo(sp7h;`7ga`*DtE|52EU%aZN`ROE5+;{hqW&^`x
z?8dhU0kQX!p@Bw^YQCst3vj0YVu-VHWR)%!q3G?%z-3Xls9kiwde+U4bv3?k#!rO2
z2LmBp{`aXqm1qw-6W8*)uT|L{*qNcv#>FE!f??E^Z#PwT7Uxa?Lho$bYr#vVH0_zJ
zE{L7(?wl{j*eNQK=YckR^cRdtFgDywg{!De)cab|$f0BbUdJEOdKn{G@2ZkisYKgH
z)_hOadU${HEW9fr+@UcgK4*&)rx7Czi&<;G%&pB%;1i^ay;jdqD7qqZde+-j>O2
z?oG(Z5hK**&Gm7=*Djq0t|j*B;ZevVRv#*=yWM}dq8~E9$#S0Y%S0mACf-nvAx$E)
z9CbaTS}QSB5Y4Y;l@r~p6t0y$qmuuY7G%+4kY3_|g%z_s1ohlkMfLGUbBd$6PvyBb3kp&
z9soYN*J57Zei&J?E>C=uQ=$hC$Bw7hjsxweY_2%b8;AX-Ji_6CT|PLFj(jrnuXRU9
zESR?2`b}7#;7qE^&+V_%Vmv2x|
z&Eigv_y6(N`o%RuzY&42QF#)?K*B=u;kV(@M<w(`ZYr?t6;wmRGRins{60mBwK(Y)
z@L$M7klT%^jghqIfimH_FUYp$xweMm^0t$0uP~DRMo8b`+U{E0VO`k2PTo-N;-fzY
zol1wZas}fapf!}5N*NU2ZrBDgEUC!%>zUi5l
zCwPlIwLM~1M&904cdZnA4r-QcOmUFvDFeP4mcqtc*S1@6YP?tw7XVmi$$VW9AwH>+{E@aWG}2j2xw=Qlbxd*B!m#wR1t
z>eQdNZR^J;W)Mk0i9*z&XeIqy$YKE!3B?1eEh`iCW-h&H*ErQb6o6PpAdui~77v#g
zV>*BO-o`7_gBx&XXJ>XsMuvo)qJkzPqt}t=)bCp0fHEP;UPg<9=0JhoE{@}>okoUB
zIr2msC3+j}&RZp}rGB~Vqr3lnp5dL+T40X&X+^jP$fMywNx=xHdMb1N*fhh
z5DL5<-+DY(f~%)TRNq|UF2Rbge-f94J6LAk<(q2Q$oY?zh=9FWL1PnNX-UeG|E#Zn
zI6tb}S!{d2P()fA?dbszCZkfwGm~)g4)56}x$St!Yw=2UE1s_7$;}Z36G0S>kHzFSG@Z^J`+bo;&8&qLKYiz-(8
zGdl5d%8fS8-{(O_Z?M{KaO+r7`-Cp`?Ah%&*K&L+<=dwD?uPtvRocW7ymQ~x^gLn&
zCJ`qfqF-$hBMWPY&mbNCdeNZb=equsc3tVANM_)hJd4agzo~GPCTtgv|D1aq&E{EW
zWs1N3ka@}!?p(b9wg}y%zyJQ-?8q4C!#%aL%{>Ti;`FBp0d4kN;jcPl>d5#pq>mG!
zp%MD(=0D{T8d0`nWQNgTqj}IiN(7!YG$0Q{J*zmJbJVuy`LAa6len!ZS|}k4k&cWW
z>OPz!m+mwL=K26b`@lCZ9|G9WoJHJw?QO3V;Lw$|-C_ogIsfh43l|+>g**GSTZ?tH
zv(RE64m2andg&o}{BbH5u)=wBImWlg^z;oaQR*`oH;5V97};{{Qu@|5qsJIBXEqBq0opJ@Fq&RJ{@|jq>bjDN8Lpqi
zU{?rPAEd$K(>XMhQ1*FdU2gQv8-Do8TCiMRDHS-ILi$q*;AcGNEWrP6n+D+kym20;_LDkVXnK$$_+fJb_+!=`a
zFUZT=vvq_h(AV>GcUS1^QjW}Y(XC0kL3c+Ag-PLeclFdKScR1P4v$LFgiSp$J(X)C
zVfq)u!iVr~*4immRF_`#czZiCS>FuY!WQYMg{*0Am^XXh3)_&NDt(ZhaLYNCUF|hn
zH^RD8IAeF?nbLrvlbu!39qVBkx52hOCiB~HVUo{TI-
zei=w~=jAe{P3dKXurC}QvrsZcxb&(+O2%mj0NL;-fG6ze&@l`#zpy|%O&fFHNI;Vo
zrJb`kr;coUsW>wV{f3MqaQAsMX{k@By(VE3O)dAAe;f6clI+0
zR8Z%6dIFo(4o0RarVcZkv-M1M!_~eDsiWqrNE4rlE;oHYUbej^b^2#uG|3=FBFVrB
zVRY@Dw2D)uFwZoM>84KBh=yNu3mue_`PMrUpZ@0u@4Bh)cpQ0dU?^V^FPmSsRvX}!
zoZGp2fB5@-h^=XFNx73!m9~T_{=v~^-KV!>I>s-ynl7-Kzux$(T9YFp7gMHQ&q-qu
zTznJstkfmE=@JG4&vamqXyp*qlfy6SV_X+pA&Y)Cv>zqQwXmf+eHB(bym?@nFEzAq
zymW!d(!#Uy2F7Kstn3Kd*I-soxo`7<4$pQyk|vZ(({m`DuGXNjHOl?uQ`nTZvyOnN
ziZA~^@(ws^yW{DG$gxp|Yf(cq35{PTVl}AZu$Zbe(3uF*1;EOA>lZobI6K|j9cd-D`U=`T
zkV*8BORB7u!C)8}caA&*?r~c=LVQ<^sj9YpvaG~xGEgEUsXCNTpE_{W@Xf&|Cr~Ps
zG4CURkU9XbuwwVYo3SypUzQ=xoo;Uf6{mVS6oV8rKJ@ShAV114nqHDlnjM4MRD}X@v4?z
zE`BR{aR;eQwV}305D+g{xcZ5N)2NpmCb{dMd+aKhzg7|`NH{Dgh!yfXK3$L+fc!Zm
zJ=U4sC9EMc4-eM;n`Xz&+}sl9qzv5XXG3;^SpSGyeF4V1$ll7A7GG{ppiqv^6Z#3v
zP4n(U^`8Pk+qwWSpD|J_q*
zh=c=NqQ?BKkUxN1{QBj)n4xej{1{GzPoAju2eQijjQ7OO9{Y7yϐ}ewmE<1P{om13ZIR;da-v
zM;oK&d?U@74==?Xt^fL@M&KFTYiZds$mqA`+L39|6!E4L&9ziXyIR*>P|HqX?G9mm
zo2sn>DM)jK<)E{4sNp8S=7ho2X+4$Y$puMlM2_Xs6D_3ZX7cH!e4Rbaru0@0`pgEjmc3J{DYsRVcJ`UfBl+KLD!TmlC5uT
zm9G7um@R3S5p??*kp3XpFGn+$A2~Ta7ZL6p=Q!1uc0pa8p0CV#jHmhXf`CJO`^~Qq
zF5~OOAGcA-Wj-qa_AZ~ZjtDa7X1PE;>N_+lD!dSr+1PGLKgwhdA1pL;W)N@GZ;@R0
znEM#;peZN$1AS>t7<5`fY$f2OBxqM5g-nK!mlYsa+5sN>-#@8D2_>9=oTQJB`a7W;l`{M&x#!bC+%~iBoG%2lb@=u_cxGK%A?{!G8diGohMMi
z>KzFp-C*3uOxkDj^j49#hS5UP1PS;aL2eK4?D#Zbd8qnM&nl{aR>lj$_w`AY2Hw=(
zKM^db6nw;jXQ~BU0`Ssm^0JSdl2RMcYw{P}r6s8huk}2L%vuAlzkdZIpDO0PAmj1k
ze!yXVT$M+P4@dX)th{u?OFJp-gDJ4hWE8Y0P#7<-`F5$9QStMH;h*g$OyV37Q1UYF
zJoe9RMgw7$KydrUEA~>^debCMkc&^e!Ct&nUNtkEcqVy
zf6)j*9P;mk^GFs!sA&8Jl(lW##_wi(J>;M8UT3-kaY&oABhLpTRy0UUjok
zA{DNOxJpplE%c1H8M8X)XCDm8UVBD)7fz36(I#pRn9cYNEQ2%6vH23Y&|8zxR~x<_{r
z!x^2+Q6fssA^(0KFBI3eOnYFg44u~dZw=GGoqNPx3>@l;2BQdrK;S_xCJwj|ip?bO
z=^Zx{GhdjftGGz_xuQGJ6U}4boMhWl^Iy_iZ8-c1!JvN$Q6eRgL6Z=8$2U8HSHdv1
z#6%VO$l8uMZM;XrTQb8=yy5PL<5~9I;VS0iXfYFyhqj^*$9mswB|HfUvHU96BbwM-
z{LqP#g1*`VZ`*T~+K_FfzlWm*eQ*@Si>jnSlwcX#r&cP(JgeZ}3kh?OUO9Cs#@bAP
zyNw_L>wt4BZg~92(({wUbDqBJ+{vja$?nvYkweHA`Jt^y7GQ&e8VL<7I^l{~mETRg
z$FoH+w#QkZ^i_O97G=aMO?IBt&HwUm8oM&MIpGX}xQ9fo(q~nqRZh2sW*Yqt;G_;{
zx^~ohC*EzNY1b#WsE>w-Blh(4q<*iSeqVLRV^mh}{!6Jur^&yCW2D1CE@Blgj*&kS
z3A~*Zg|a@URU!?8B+>qx9eVF~Wpi~Z74P?xe)=w(HMXjKG1Gp!;Dzze(sDGTZ&%QK
zyZN%Qig~1S`Jq{tVr1)l+KLZFkPjHd*Z;
zVBi*DFRhTm=J;8Q2L|RfSlRv4Y#GKCDISC3VEJ_9ukc?%VVJP$!<|9$mY1ObqFn1LDLsMXPSB8ER2
zm5m|L|CGtD6p+!o!^d_13Zw&UYrIF9DHw+Mt2W?23|ogfW;AA|oC+P~Yrgm9X7z2G
zeOZP!L1z`q9m(#8WOO*o1e43{=6`t+dPWbyyXiu}e}q8l4*u=GFCgK>YUfIzad9^(
z<>u(s0K;hd(^DZ<$jg#c=a*DvWp5>mI40R}l&$+BbZY-EarTbaEL49!{mzVcY)vO1xHubk5b_{wa=R%Vd$jLig=GT?vdpguX5fVS7MD33ID2h|r1LM>yUsDp{L2wnj
z(SIF&VI=3jC!dZUt7!LC^Fj>Mkg*;X&?lC}*eC&>`wEzXtIKb8
zKbpCsv7PdUwmqm$wSLB(#;CQWW!7Cr=D3CR7vR6_@1N}LJ!^=MS>ew}Y5aZKM9v=K
zn`0P*d!(-k0qc9panqN^5NgVsl>rJA%^K$
z1B>1Uj(0iriPmo5cSqRhw=`VZV7j2Jy`V4xfe;QSxZs5>&5X6{xME=9&?f;P+TwI9
zP?{%^;RE~;jc|op*3Pc!zOxg`Mi!n{)Yco*7>j9?ndxM#znGL;eht1tQ<<&XFU()i
zPE=i3nTi#a@}@1-+ZOC;+8dS6>%2bE|1)^b*ZZ|GJM6g%_1MR1Hsx1|&%_ufoe<|@SgKE?Hm$*R|jDY$f8s4Y`1smAhk=I67UHaftGM(%M}
zk?keZjNHDxSv^_Nw{LH1shD09e(I)Pn0#5%KZxd4tgz*)jJ1rwL4liZg@r5N81(3v
zMzT9=f|Ca8q)?dUQ}Nd_p%)k{R^%ZSVuPV!opY|GklHQQt7}*9@E5@3vDll@UtFmq
z#R~Z#1@IAs*w5(u@mKKE!kb&}B`6*L1(622gF3%e+}#W7x4u-C#*zT^u#)yljKS2>0B-;1BPz+uD@_wLzrKggtbr4fF!kg%?_6VWc(@u_0e3LnX7cn$f`plna+-&Wg^
z-PzXp@%g{J)3}CJkY`GeBCN>5AI3`hm2z(Zgg1uK3)C1+7MiS=jypI+cyp`ig3(;f
zv}g1cx&JDmuI$&6nb%1_H*$Cz6HTndSbg1#rH7pef!wc?b{1QPod60hGunP71$Fqz)*a(CO%k9Vn?
zmnT+<4y7WM-1mKqK6En=fZj)D{h?m`NPFXgMf`E0
zj^xMTJ`OvbNw;%>Kdi%QD{N(b4IA=>%MKOaIRrdWP@KmMX3r$v|_#s?u4n5$Z(Y$b$+f7x(;%AWq<
zD~xZ+WVRRpW@1LOn_@!RU%pS>a_=vY*mOhB$*}a_igAj-^B|}M5APIDNk|r53nDc+ddFN+I
zN>YZ4jKZ?nVIFSv*k2rm&k^!S&G0YQhKAoR2?Y>?+2JOV=|#ey$79_Ok88y9XCE=7
zy4AgnJLf;)eAse=vzU(T%_|)%uodMox4UFYry=`r6Mlap@-syV+NzX2uJUDem3#k-*$YrdWxlHE||GF_j1}=k?AQeKdBf1?s#-8Q
z$Xr{F#{fbbj@-QY9cBCqc=TnCn_O`5lXnvD2&3K+WnMzT6vcTo;|*;0?Dx>vnuJ~M
zx+G&K-&>MY9QG%5a*4Nqk8-bc*X3|rs5_8ynrvf(EKM?>PdpZ>v5IYan9x3D(NPXCQdU0Z>sA8
z7Pf)B<$t5ZX`Y*%R!E7N-2W_kyhV?pX7Wh1x~K)ayFcr1>HnsL?$vQWRAoR&EvOSd
zbv-Z#V%GRYdp{=aj7Hsb&HB)(-_bLKo!0ja+7l-|dyHX}3|ItTLqb$>AWv~HS51J-
z^_@#2ccGsB>+HWAO}c5YH(m({n))cWH-$b8;r`C|lc#n^1_+cP=jGot_rB;^?gwxI
z`IiWYyu6Iy7XD#W>UIq+ZCw=Vro#QK-s~TQVVxW#}xC3$lyb
z2VsZVi)Vkm!s>XBVzQ6h&Wg`<)nu&+|9_mr&i*;&?l~xY{8q{Sb}(Su;wsHW-43MB
z-*(2+?tFqkIkv2%EF4Pt*6Qq&sPg+rKDYIu%^^mS*>9PM`=5+V-$uQCGRCA9GAS2$
z3d`pG-Nt
zsu>I62HDIEcHR@l9!C&w^d>{BJwo(ssOM&>;v8
z3u(YvVC(mzuRTw>GwMmiib``qT`Ps|XWOVtNnFqleHQAfhl~ZGPz)otV@V;^4uw4z
z@XLJ-J2L*i_`?PZrUfl^pGfw(#rZ(Zt*q@_Hnh4d8OZ@HsYUwOGRWUxHTwei9X%Y1
zVMhqP*JxkGVZ137cI0+r^A#|iv{aX#T|QWM20g8mP+;%NP_!jv3^~`gH5mxy>Vr;7
zBC6r#=ZV^;?9}gv!T!LytQer6dDN;Fv0ZdA&6{he{LXNe2Cd}R`X^mTUR4|^xMmSH
z0yL*JAO1Z!2*1Ty7#qUUCsgBSPdzt+)EurjL(|NxbiMD;(>{s2_r+XGW?}L|{;uAB%R2Nlg-D|IV7aA>HNTR;#0l8
z-?@?<{&Bdfln5^>BxkeXj~-n~iWQ7X;{!I0^O|2sSS}-hdPktljlQr<{wY$K>gA)r
z>%U?sLIw-<*o)xDmUpa+NBK)Y(+$~RoMM&JVIU0O|VbomVIt!<#wx_6e`)N_E}lo
z*~rP%-Wl#2I<5Ax8okj=q3o6rwXM7r)BdU!+98_=|Ah_4N^jqV5wAf~`1rb~+%il?
zg6wX4Bds(BL?eDc+Y4S&JbiNm_A^FLw~t1mbHD1B>rTts1E!JA&KDIwt(!wdJ&G(M
zO{+(?ZzuXFVr=TB-CtqLpE#O&bSM_RYQ&+-BQ}1iGe|N$d)N9t)j^wZ9GBbeVzSKg
zE|$)%ayt1!$@ys5j?#(UdHMN%*guK0l^7XDPz3JuMjX39k&aZB^X=no`VQmcN$ioZ
zcmIV45&Sq52CM|8c9al~?`!
zU{r%-6QC(9?(~gVucJg@u>q`iJvjO0LG;!}T!U5H$-Z_<#;Q<($bwoyUCjXF$lH4n
za!`is_Ujknv9#b4?O$W9qwTVh)9~#`*(Re=+&@Hyh(q&t*f)WM({^YZZ}Fv<1R~n$
zhJkS|_-@FA*eQjHpZ{Lm_B?1i8z*Oa9ll$!D%>%R|v|MqWc+dd-5Jx%Pe2_XOW5T}M&5eieQbY`~?d>gdZ#=NvE
z!;Y3?CMPe5i-@TD3U$qU*34fS1loM}PaIKIU6NXr-EnRklRZ>D>m}2qN4wBd0=MJI
z11A8;b70SW?mFowo%W1~a)fBv%xwFY>O_7WjOkqd`xlRQ_V#X{C>N}oaCHNN=UR?L
zp-U3N$Ayg_9{*##o+|RX<6BKy+($|;w;<6t%Z5tO`SkiBi<{OqTw^G>SC7j
z{S^6D?47`Kc~y2CrixeE1*ix)@AIQrR&Km?e2zSPinynEFXU){tvD| |